diff --git a/ortools/base/BUILD.bazel b/ortools/base/BUILD.bazel index f623f480fd3..1439a16a8b5 100644 --- a/ortools/base/BUILD.bazel +++ b/ortools/base/BUILD.bazel @@ -113,15 +113,25 @@ cc_binary( srcs = ["constant_divisor_benchmark.cc"], deps = [ ":constant_divisor", + "//ortools/base:benchmark_main", "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) +cc_library( + name = "benchmark_main", + srcs = ["benchmark_main.cc"], + deps = [ + "//ortools/base", + "@google_benchmark//:benchmark", + ], + alwayslink = True, +) + cc_library( name = "container_logging", hdrs = ["container_logging.h"], diff --git a/ortools/base/CMakeLists.txt b/ortools/base/CMakeLists.txt index 861587487c3..932d60a952a 100644 --- a/ortools/base/CMakeLists.txt +++ b/ortools/base/CMakeLists.txt @@ -15,6 +15,7 @@ file(GLOB _SRCS "*.h" "*.cc") list(FILTER _SRCS EXCLUDE REGEX ".*/.*_benchmark.cc") list(FILTER _SRCS EXCLUDE REGEX ".*/.*_test.cc") list(FILTER _SRCS EXCLUDE REGEX "/gmock\.h") +list(FILTER _SRCS EXCLUDE REGEX ".*/benchmark_main\.cc") set(NAME ${PROJECT_NAME}_base) diff --git a/ortools/base/benchmark_main.cc b/ortools/base/benchmark_main.cc new file mode 100644 index 00000000000..13df25c5e22 --- /dev/null +++ b/ortools/base/benchmark_main.cc @@ -0,0 +1,26 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "benchmark/benchmark.h" +#include "ortools/base/init_google.h" + +int main(int argc, char* argv[]) { + benchmark::MaybeReenterWithoutASLR(argc, argv); + benchmark::Initialize(&argc, argv); + InitGoogle(argv[0], &argc, &argv, false); + benchmark::RunSpecifiedBenchmarks(); + benchmark::Shutdown(); + return EXIT_SUCCESS; +} diff --git a/ortools/glop/BUILD.bazel b/ortools/glop/BUILD.bazel index 52faa373978..17d8851c69e 100644 --- a/ortools/glop/BUILD.bazel +++ b/ortools/glop/BUILD.bazel @@ -245,6 +245,7 @@ cc_library( "//ortools/util:time_limit", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/flags:flag", + "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:die_if_null", "@abseil-cpp//absl/log:vlog_is_on", @@ -338,6 +339,7 @@ cc_library( "//ortools/util:logging", "//ortools/util:testing_utils", "//ortools/util:time_limit", + "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", diff --git a/ortools/glop/entering_variable.cc b/ortools/glop/entering_variable.cc index 594a51c4b60..58bbcc4aa60 100644 --- a/ortools/glop/entering_variable.cc +++ b/ortools/glop/entering_variable.cc @@ -40,11 +40,9 @@ EnteringVariable::EnteringVariable(const VariablesInfo& variables_info, reduced_costs_(reduced_costs), parameters_() {} -Status EnteringVariable::DualChooseEnteringColumn( +ColIndex EnteringVariable::DualChooseEnteringColumn( bool nothing_to_recompute, const UpdateRow& update_row, - Fractional cost_variation, std::vector* bound_flip_candidates, - ColIndex* entering_col) { - GLOP_RETURN_ERROR_IF_NULL(entering_col); + Fractional cost_variation, std::vector* bound_flip_candidates) { const auto update_coefficients = update_row.GetCoefficients().const_view(); const auto reduced_costs = reduced_costs_->GetReducedCosts(); SCOPED_TIME_STAT(&stats_); @@ -145,7 +143,7 @@ Status EnteringVariable::DualChooseEnteringColumn( // - We have processed all breakpoints with a ratio smaller than it. harris_ratio = std::numeric_limits::max(); - *entering_col = kInvalidCol; + ColIndex entering_col = kInvalidCol; bound_flip_candidates->clear(); Fractional step = 0.0; Fractional best_coeff = -1.0; @@ -196,12 +194,12 @@ Status EnteringVariable::DualChooseEnteringColumn( top.ratio + harris_tolerance / top.coeff_magnitude)); if (top.coeff_magnitude == best_coeff && top.ratio == step) { - DCHECK_NE(*entering_col, kInvalidCol); + DCHECK_NE(entering_col, kInvalidCol); equivalent_entering_choices_.push_back(top.col); } else { equivalent_entering_choices_.clear(); best_coeff = top.coeff_magnitude; - *entering_col = top.col; + entering_col = top.col; // Note that the step is not directly used, so it is okay to leave it // negative. @@ -217,15 +215,15 @@ Status EnteringVariable::DualChooseEnteringColumn( // Break the ties randomly. if (!equivalent_entering_choices_.empty()) { - equivalent_entering_choices_.push_back(*entering_col); - *entering_col = + equivalent_entering_choices_.push_back(entering_col); + entering_col = equivalent_entering_choices_[std::uniform_int_distribution( 0, equivalent_entering_choices_.size() - 1)(random_)]; IF_STATS_ENABLED( stats_.num_perfect_ties.Add(equivalent_entering_choices_.size())); } - if (*entering_col == kInvalidCol) return Status::OK(); + if (entering_col == kInvalidCol) return entering_col; // If best_coeff is small and they are potential bound flips, we can take a // smaller step but use a good pivot. @@ -239,18 +237,17 @@ Status EnteringVariable::DualChooseEnteringColumn( VLOG(1) << "Used bound flip to avoid bad pivot. Before: " << best_coeff << " now: " << std::abs(update_coefficients[col]); - *entering_col = col; + entering_col = col; break; } } - return Status::OK(); + return entering_col; } -Status EnteringVariable::DualPhaseIChooseEnteringColumn( +ColIndex EnteringVariable::DualPhaseIChooseEnteringColumn( bool nothing_to_recompute, const UpdateRow& update_row, - Fractional cost_variation, ColIndex* entering_col) { - GLOP_RETURN_ERROR_IF_NULL(entering_col); + Fractional cost_variation) { const auto update_coefficients = update_row.GetCoefficients().const_view(); const auto reduced_costs = reduced_costs_->GetReducedCosts(); SCOPED_TIME_STAT(&stats_); @@ -334,7 +331,7 @@ Status EnteringVariable::DualPhaseIChooseEnteringColumn( // Select the last breakpoint that still improves the infeasibility and has a // numerically stable pivot. - *entering_col = kInvalidCol; + ColIndex entering_col = kInvalidCol; Fractional step = -1.0; Fractional improvement = std::abs(cost_variation); while (!breakpoints_.empty()) { @@ -344,7 +341,7 @@ Status EnteringVariable::DualPhaseIChooseEnteringColumn( DCHECK(top.ratio > step || (top.ratio == step && top.coeff_magnitude <= pivot_magnitude)); if (top.ratio > step && top.coeff_magnitude >= pivot_magnitude) { - *entering_col = top.col; + entering_col = top.col; step = top.ratio; pivot_magnitude = top.coeff_magnitude; } @@ -362,7 +359,7 @@ Status EnteringVariable::DualPhaseIChooseEnteringColumn( std::pop_heap(breakpoints_.begin(), breakpoints_.end()); breakpoints_.pop_back(); } - return Status::OK(); + return entering_col; } void EnteringVariable::SetParameters(const GlopParameters& parameters) { diff --git a/ortools/glop/entering_variable.h b/ortools/glop/entering_variable.h index d7d499525b2..911aba4bb24 100644 --- a/ortools/glop/entering_variable.h +++ b/ortools/glop/entering_variable.h @@ -55,18 +55,17 @@ class EnteringVariable { // the "update" row vector in the direction given by the sign of // cost_variation. Computes the smallest step that keeps the dual feasibility // for all the columns. - ABSL_MUST_USE_RESULT Status DualChooseEnteringColumn( + ColIndex DualChooseEnteringColumn( bool nothing_to_recompute, const UpdateRow& update_row, - Fractional cost_variation, std::vector* bound_flip_candidates, - ColIndex* entering_col); + Fractional cost_variation, std::vector* bound_flip_candidates); // Dual feasibility phase (i.e. phase I) ratio test. // Similar to the optimization phase test, but allows a step that increases // the infeasibility of an already infeasible column. The step magnitude is // the one that minimize the sum of infeasibilities when applied. - ABSL_MUST_USE_RESULT Status DualPhaseIChooseEnteringColumn( - bool nothing_to_recompute, const UpdateRow& update_row, - Fractional cost_variation, ColIndex* entering_col); + ColIndex DualPhaseIChooseEnteringColumn(bool nothing_to_recompute, + const UpdateRow& update_row, + Fractional cost_variation); // Sets the parameters. void SetParameters(const GlopParameters& parameters); diff --git a/ortools/glop/lp_solver.cc b/ortools/glop/lp_solver.cc index 3a882b8b50c..47d52bf939d 100644 --- a/ortools/glop/lp_solver.cc +++ b/ortools/glop/lp_solver.cc @@ -152,15 +152,11 @@ SolverLogger& LPSolver::GetSolverLogger() { return logger_; } ProblemStatus LPSolver::Solve(const LinearProgram& lp) { std::unique_ptr time_limit = TimeLimit::FromParameters(parameters_); - return SolveWithTimeLimit(lp, time_limit.get()); + return SolveWithTimeLimit(lp, *time_limit); } ProblemStatus LPSolver::SolveWithTimeLimit(const LinearProgram& lp, - TimeLimit* time_limit) { - if (time_limit == nullptr) { - LOG(DFATAL) << "SolveWithTimeLimit() called with a nullptr time_limit."; - return ProblemStatus::ABNORMAL; - } + TimeLimit& time_limit) { ++num_solves_; num_revised_simplex_iterations_ = 0; DumpLinearProgramIfRequiredByFlags(lp, num_solves_); @@ -227,7 +223,7 @@ ProblemStatus LPSolver::SolveWithTimeLimit(const LinearProgram& lp, // Preprocess. MainLpPreprocessor preprocessor(¶meters_); preprocessor.SetLogger(&logger_); - preprocessor.SetTimeLimit(time_limit); + preprocessor.SetTimeLimit(&time_limit); const bool postsolve_is_needed = preprocessor.Run(¤t_linear_program_); @@ -261,7 +257,7 @@ ProblemStatus LPSolver::SolveWithTimeLimit(const LinearProgram& lp, // Do not launch the solver if the time limit was already reached. This might // mean that the pre-processors were not all run, and current_linear_program_ // might not be in a completely safe state. - if (!time_limit->LimitReached()) { + if (!time_limit.LimitReached()) { RunRevisedSimplexIfNeeded(&solution, time_limit); } if (postsolve_is_needed) preprocessor.DestructiveRecoverSolution(&solution); @@ -271,15 +267,24 @@ ProblemStatus LPSolver::SolveWithTimeLimit(const LinearProgram& lp, SOLVER_LOG(&logger_, "status: ", GetProblemStatusString(status)); SOLVER_LOG(&logger_, "objective: ", GetObjectiveValue()); SOLVER_LOG(&logger_, "iterations: ", GetNumberOfSimplexIterations()); - SOLVER_LOG(&logger_, "time: ", time_limit->GetElapsedTime()); + SOLVER_LOG(&logger_, "time: ", time_limit.GetElapsedTime()); SOLVER_LOG(&logger_, "deterministic_time: ", - time_limit->GetElapsedDeterministicTime()); + time_limit.GetElapsedDeterministicTime()); SOLVER_LOG(&logger_, ""); } return status; } +ProblemStatus LPSolver::SolveWithTimeLimit(const LinearProgram& lp, + TimeLimit* time_limit) { + if (time_limit == nullptr) { + LOG(DFATAL) << "SolveWithTimeLimit() called with a nullptr time_limit."; + return ProblemStatus::ABNORMAL; + } + return SolveWithTimeLimit(lp, *time_limit); +} + void LPSolver::Clear() { ResizeSolution(RowIndex(0), ColIndex(0)); revised_simplex_.reset(nullptr); @@ -609,7 +614,7 @@ void LPSolver::ResizeSolution(RowIndex num_rows, ColIndex num_cols) { } void LPSolver::RunRevisedSimplexIfNeeded(ProblemSolution* solution, - TimeLimit* time_limit) { + TimeLimit& time_limit) { // Note that the transpose matrix is no longer needed at this point. // This helps reduce the peak memory usage of the solver. // diff --git a/ortools/glop/lp_solver.h b/ortools/glop/lp_solver.h index fb40fbb559b..efe1c7b5144 100644 --- a/ortools/glop/lp_solver.h +++ b/ortools/glop/lp_solver.h @@ -17,6 +17,7 @@ #include #include +#include "absl/base/attributes.h" #include "ortools/glop/parameters.pb.h" #include "ortools/glop/revised_simplex.h" #include "ortools/lp_data/lp_data.h" @@ -63,6 +64,13 @@ class LPSolver { // Same as Solve() but use the given time limit rather than constructing a new // one from the current GlopParameters. + ABSL_MUST_USE_RESULT ProblemStatus SolveWithTimeLimit(const LinearProgram& lp, + TimeLimit& time_limit); + + // Legacy version of SolveWithTimeLimit() passing a pointer on TimeLimit and + // expecting it to be non-null (it returns ProblemStatus::ABNORMAL and use + // LOG(DFATAL) when null). + ABSL_DEPRECATED("Use SolveWithTimeLimit(const LinearProgram&, TimeLimit&)."); ABSL_MUST_USE_RESULT ProblemStatus SolveWithTimeLimit(const LinearProgram& lp, TimeLimit* time_limit); @@ -193,7 +201,7 @@ class LPSolver { // Runs the revised simplex algorithm if needed (i.e. if the program was not // already solved by the preprocessors). void RunRevisedSimplexIfNeeded(ProblemSolution* solution, - TimeLimit* time_limit); + TimeLimit& time_limit); // Checks that the returned solution values and statuses are consistent. // Returns true if this is the case. See the code for the exact check diff --git a/ortools/glop/revised_simplex.cc b/ortools/glop/revised_simplex.cc index 24010bad3db..514d0335d25 100644 --- a/ortools/glop/revised_simplex.cc +++ b/ortools/glop/revised_simplex.cc @@ -18,12 +18,14 @@ #include #include #include +#include #include #include #include #include "absl/flags/flag.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/log/vlog_is_on.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" @@ -156,8 +158,8 @@ void RevisedSimplex::SetStartingVariableValuesForNextSolve( Status RevisedSimplex::MinimizeFromTransposedMatrixWithSlack( const DenseRow& objective, Fractional objective_scaling_factor, - Fractional objective_offset, TimeLimit* time_limit) { - const double start_time = time_limit->GetElapsedTime(); + Fractional objective_offset, TimeLimit& time_limit) { + const double start_time = time_limit.GetElapsedTime(); default_logger_.EnableLogging(parameters_.log_search_progress()); default_logger_.SetLogToStdOut(parameters_.log_to_stdout()); parameters_ = initial_parameters_; @@ -201,8 +203,16 @@ Status RevisedSimplex::MinimizeFromTransposedMatrixWithSlack( return SolveInternal(start_time, false, objective, time_limit); } -Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { - const double start_time = time_limit->GetElapsedTime(); +Status RevisedSimplex::MinimizeFromTransposedMatrixWithSlack( + const DenseRow& objective, Fractional objective_scaling_factor, + Fractional objective_offset, TimeLimit* time_limit) { + GLOP_RETURN_ERROR_IF_NULL(time_limit); + return MinimizeFromTransposedMatrixWithSlack( + objective, objective_scaling_factor, objective_offset, *time_limit); +} + +Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit& time_limit) { + const double start_time = time_limit.GetElapsedTime(); default_logger_.EnableLogging(parameters_.log_search_progress()); default_logger_.SetLogToStdOut(parameters_.log_to_stdout()); @@ -212,17 +222,21 @@ Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { lp.objective_coefficients(), time_limit); } +Status RevisedSimplex::Solve(const LinearProgram& lp, TimeLimit* time_limit) { + GLOP_RETURN_ERROR_IF_NULL(time_limit); + return Solve(lp, *time_limit); +} + ABSL_MUST_USE_RESULT Status RevisedSimplex::SolveInternal( double start_time, bool is_maximization_problem, - const DenseRow& objective_coefficients, TimeLimit* time_limit) { + const DenseRow& objective_coefficients, TimeLimit& time_limit) { SCOPED_TIME_STAT(&function_stats_); - GLOP_RETURN_ERROR_IF_NULL(time_limit); Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); SOLVER_LOG(logger_, ""); - primal_edge_norms_.SetTimeLimit(time_limit); - dual_edge_norms_.SetTimeLimit(time_limit); + primal_edge_norms_.SetTimeLimit(&time_limit); + dual_edge_norms_.SetTimeLimit(&time_limit); if (logger_->LoggingIsEnabled()) { DisplayBasicVariableStatistics(); @@ -371,7 +385,7 @@ ABSL_MUST_USE_RESULT Status RevisedSimplex::SolveInternal( DisplayErrors(); phase_ = Phase::OPTIMIZATION; - feasibility_time_ = time_limit->GetElapsedTime() - start_time; + feasibility_time_ = time_limit.GetElapsedTime() - start_time; primal_edge_norms_.SetPricingRule(parameters_.optimization_rule()); num_feasibility_iterations_ = num_iterations_; @@ -396,7 +410,7 @@ ABSL_MUST_USE_RESULT Status RevisedSimplex::SolveInternal( !objective_limit_reached_ && (num_iterations_ == 0 || num_iterations_ < parameters_.max_number_of_iterations()) && - !time_limit->LimitReached() && + !time_limit.LimitReached() && !absl::GetFlag(FLAGS_simplex_stop_after_feasibility) && (problem_status_ == ProblemStatus::PRIMAL_FEASIBLE || problem_status_ == ProblemStatus::DUAL_FEASIBLE); @@ -650,7 +664,7 @@ ABSL_MUST_USE_RESULT Status RevisedSimplex::SolveInternal( } } - total_time_ = time_limit->GetElapsedTime() - start_time; + total_time_ = time_limit.GetElapsedTime() - start_time; optimization_time_ = total_time_ - feasibility_time_; num_optimization_iterations_ = num_iterations_ - num_feasibility_iterations_; @@ -688,7 +702,7 @@ ABSL_MUST_USE_RESULT Status RevisedSimplex::SolveInternal( } } - total_time_ = time_limit->GetElapsedTime() - start_time; + total_time_ = time_limit.GetElapsedTime() - start_time; push_time_ = total_time_ - feasibility_time_ - optimization_time_; num_push_iterations_ = num_iterations_ - num_feasibility_iterations_ - num_optimization_iterations_; @@ -2715,10 +2729,9 @@ int RevisedSimplex::NumNonIntegerInBasis() const { return num_non_integer; } -Status RevisedSimplex::PrimalPolish(TimeLimit* time_limit) { - GLOP_RETURN_ERROR_IF_NULL(time_limit); +Status RevisedSimplex::PrimalPolish(TimeLimit& time_limit) { Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); // Get all non-basic variables with a reduced costs close to zero. // Note that because we only choose entering candidate with a cost of zero, @@ -2739,7 +2752,7 @@ Status RevisedSimplex::PrimalPolish(TimeLimit* time_limit) { Fractional total_gain = 0.0; for (int i = 0; i < 1'000; ++i) { AdvanceDeterministicTime(time_limit); - if (time_limit->LimitReached()) break; + if (time_limit.LimitReached()) break; if (num_pivots >= 100) break; if (candidates.empty()) break; @@ -2868,10 +2881,9 @@ void RevisedSimplex::FillWithNonIntegerInBasis( } } -Status RevisedSimplex::DualPolish(TimeLimit* time_limit) { - GLOP_RETURN_ERROR_IF_NULL(time_limit); +Status RevisedSimplex::DualPolish(TimeLimit& time_limit) { Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); int num_pivots = 0; Fractional total_gain = 0.0; @@ -2920,11 +2932,10 @@ Status RevisedSimplex::DualPolish(TimeLimit* time_limit) { // TODO(user): just look at 2/3 position with higest coeff and reduced // cost close to zero ? - ColIndex entering_col = kInvalidCol; bound_flip_candidates_.clear(); - GLOP_RETURN_IF_ERROR(entering_variable_.DualChooseEnteringColumn( + const ColIndex entering_col = entering_variable_.DualChooseEnteringColumn( reduced_costs_.AreReducedCostsPrecise(), update_row_, cost_variation, - &bound_flip_candidates_, &entering_col)); + &bound_flip_candidates_); if (entering_col == kInvalidCol) continue; // When we are at optimal, only moves with a reduced cost of zero should @@ -3040,10 +3051,9 @@ Status RevisedSimplex::DualPolish(TimeLimit* time_limit) { // enter the basis, and a variable from x_B is selected to leave the basis. // To avoid explicit inversion of B, the algorithm solves two sub-systems: // y.B = c_B and B.d = a (a being the entering column). -Status RevisedSimplex::PrimalMinimize(TimeLimit* time_limit) { - GLOP_RETURN_ERROR_IF_NULL(time_limit); +Status RevisedSimplex::PrimalMinimize(TimeLimit& time_limit) { Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); num_consecutive_degenerate_iterations_ = 0; bool refactorize = false; last_refactorization_reason_ = RefactorizationReason::DEFAULT; @@ -3063,7 +3073,7 @@ Status RevisedSimplex::PrimalMinimize(TimeLimit* time_limit) { while (true) { AdvanceDeterministicTime(time_limit); - if (time_limit->LimitReached()) break; + if (time_limit.LimitReached()) break; // TODO(user): we may loop a bit more than the actual number of iteration. // fix. @@ -3346,9 +3356,9 @@ Status RevisedSimplex::PrimalMinimize(TimeLimit* time_limit) { // // Note that the returned status applies to the primal problem! Status RevisedSimplex::DualMinimize(bool feasibility_phase, - TimeLimit* time_limit) { + TimeLimit& time_limit) { Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); num_consecutive_degenerate_iterations_ = 0; bool refactorize = false; last_refactorization_reason_ = RefactorizationReason::DEFAULT; @@ -3360,12 +3370,9 @@ Status RevisedSimplex::DualMinimize(bool feasibility_phase, Fractional cost_variation; Fractional target_bound; - // Entering variable. - ColIndex entering_col; - while (true) { AdvanceDeterministicTime(time_limit); - if (time_limit->LimitReached()) break; + if (time_limit.LimitReached()) break; // TODO(user): we may loop a bit more than the actual number of iteration. // fix. @@ -3515,15 +3522,14 @@ Status RevisedSimplex::DualMinimize(bool feasibility_phase, } update_row_.ComputeUpdateRow(leaving_row); - if (feasibility_phase) { - GLOP_RETURN_IF_ERROR(entering_variable_.DualPhaseIChooseEnteringColumn( - reduced_costs_.AreReducedCostsPrecise(), update_row_, cost_variation, - &entering_col)); - } else { - GLOP_RETURN_IF_ERROR(entering_variable_.DualChooseEnteringColumn( - reduced_costs_.AreReducedCostsPrecise(), update_row_, cost_variation, - &bound_flip_candidates_, &entering_col)); - } + const ColIndex entering_col = + feasibility_phase + ? entering_variable_.DualPhaseIChooseEnteringColumn( + reduced_costs_.AreReducedCostsPrecise(), update_row_, + cost_variation) + : entering_variable_.DualChooseEnteringColumn( + reduced_costs_.AreReducedCostsPrecise(), update_row_, + cost_variation, &bound_flip_candidates_); // No entering_col: dual unbounded (i.e. primal infeasible). if (entering_col == kInvalidCol) { @@ -3667,10 +3673,9 @@ Status RevisedSimplex::DualMinimize(bool feasibility_phase, return Status::OK(); } -Status RevisedSimplex::PrimalPush(TimeLimit* time_limit) { - GLOP_RETURN_ERROR_IF_NULL(time_limit); +Status RevisedSimplex::PrimalPush(TimeLimit& time_limit) { Cleanup update_deterministic_time_on_return( - [this, time_limit]() { AdvanceDeterministicTime(time_limit); }); + [this, &time_limit]() { AdvanceDeterministicTime(time_limit); }); bool refactorize = false; // We clear all the quantities that we don't update so they will be recomputed @@ -3690,7 +3695,7 @@ Status RevisedSimplex::PrimalPush(TimeLimit* time_limit) { while (!super_basic_cols.empty()) { AdvanceDeterministicTime(time_limit); - if (time_limit->LimitReached()) break; + if (time_limit.LimitReached()) break; ScopedTimeDistributionUpdater timer(&iteration_stats_.total); GLOP_RETURN_IF_ERROR(RefactorizeBasisIfNeeded(&refactorize)); @@ -4202,12 +4207,11 @@ void RevisedSimplex::DisplayProblem() { } } -void RevisedSimplex::AdvanceDeterministicTime(TimeLimit* time_limit) { - DCHECK(time_limit != nullptr); +void RevisedSimplex::AdvanceDeterministicTime(TimeLimit& time_limit) { const double current_deterministic_time = DeterministicTime(); const double deterministic_time_delta = current_deterministic_time - last_deterministic_time_update_; - time_limit->AdvanceDeterministicTime(deterministic_time_delta); + time_limit.AdvanceDeterministicTime(deterministic_time_delta); last_deterministic_time_update_ = current_deterministic_time; } diff --git a/ortools/glop/revised_simplex.h b/ortools/glop/revised_simplex.h index 357ecaa2c4a..b117c06eb2c 100644 --- a/ortools/glop/revised_simplex.h +++ b/ortools/glop/revised_simplex.h @@ -154,6 +154,12 @@ class RevisedSimplex { // and try to use the previously computed solution as a warm-start. To disable // this behavior or give explicit warm-start data, use one of the State*() // functions below. + ABSL_MUST_USE_RESULT Status Solve(const LinearProgram& lp, + TimeLimit& time_limit); + + // Legacy version of Solve() passing a pointer on TimeLimit and expecting it + // to be non-null (it returns Status::ERROR_NULL when null). + ABSL_DEPRECATED("Use Solve(const LinearProgram&, TimeLimit&) instead."); ABSL_MUST_USE_RESULT Status Solve(const LinearProgram& lp, TimeLimit* time_limit); @@ -266,6 +272,15 @@ class RevisedSimplex { DenseRow* MutableUpperBounds() { return variables_info_.MutableUpperBounds(); } + ABSL_MUST_USE_RESULT Status MinimizeFromTransposedMatrixWithSlack( + const DenseRow& objective, Fractional objective_scaling_factor, + Fractional objective_offset, TimeLimit& time_limit); + + // Legacy version of MinimizeFromTransposedMatrixWithSlack() passing a pointer + // on TimeLimit and expecting it to be non-null (it returns + // Status::ERROR_NULL). + ABSL_DEPRECATED( + "Use MinimizeFromTransposedMatrixWithSlack(..., TimeLimit&) instead."); ABSL_MUST_USE_RESULT Status MinimizeFromTransposedMatrixWithSlack( const DenseRow& objective, Fractional objective_scaling_factor, Fractional objective_offset, TimeLimit* time_limit); @@ -323,7 +338,7 @@ class RevisedSimplex { ABSL_MUST_USE_RESULT Status SolveInternal(double start_time, bool maximize, const DenseRow& objective, - TimeLimit* time_limit); + TimeLimit& time_limit); // Propagates parameters_ to all the other classes that need it. // @@ -625,18 +640,18 @@ class RevisedSimplex { Status RefactorizeBasisIfNeeded(bool* refactorize); // Main iteration loop of the primal simplex. - ABSL_MUST_USE_RESULT Status PrimalMinimize(TimeLimit* time_limit); + ABSL_MUST_USE_RESULT Status PrimalMinimize(TimeLimit& time_limit); // Main iteration loop of the dual simplex. ABSL_MUST_USE_RESULT Status DualMinimize(bool feasibility_phase, - TimeLimit* time_limit); + TimeLimit& time_limit); // Pushes all super-basic variables to bounds (if applicable) or to zero (if // unconstrained). This is part of a "crossover" procedure to find a vertex // solution given a (near) optimal solution. Assumes that Minimize() or // DualMinimize() has already run, i.e., that we are at an optimal solution // within numerical tolerances. - ABSL_MUST_USE_RESULT Status PrimalPush(TimeLimit* time_limit); + ABSL_MUST_USE_RESULT Status PrimalPush(TimeLimit& time_limit); // Experimental. This is useful in a MIP context. It performs a few degenerate // pivot to try to mimize the fractionality of the optimal basis. @@ -646,8 +661,8 @@ class RevisedSimplex { // // I could only find slides for the reference of this "LP Solution Polishing // to improve MIP Performance", Matthias Miltenberger, Zuse Institute Berlin. - ABSL_MUST_USE_RESULT Status PrimalPolish(TimeLimit* time_limit); - ABSL_MUST_USE_RESULT Status DualPolish(TimeLimit* time_limit); + ABSL_MUST_USE_RESULT Status PrimalPolish(TimeLimit& time_limit); + ABSL_MUST_USE_RESULT Status DualPolish(TimeLimit& time_limit); // Helper function for Primal/DualPolish(). Fractional IntegralityChange(ColIndex col, Fractional old_value, @@ -665,7 +680,7 @@ class RevisedSimplex { // during the last call to this method. // TODO(user): Update the internals of revised simplex so that the time // limit is updated at the source and remove this method. - void AdvanceDeterministicTime(TimeLimit* time_limit); + void AdvanceDeterministicTime(TimeLimit& time_limit); // Problem status ProblemStatus problem_status_; diff --git a/ortools/graph/BUILD.bazel b/ortools/graph/BUILD.bazel index 9182c029855..61d8cd92b17 100644 --- a/ortools/graph/BUILD.bazel +++ b/ortools/graph/BUILD.bazel @@ -29,6 +29,7 @@ cc_library( "//ortools/base:threadpool", "//ortools/base:top_n", "//ortools/graph_base:graph", + "//ortools/util:bitset", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/log:check", @@ -59,13 +60,13 @@ cc_binary( srcs = ["bounded_dijkstra_benchmark.cc"], deps = [ ":bounded_dijkstra", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "//ortools/graph_base:test_util", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -170,6 +171,7 @@ cc_binary( srcs = ["cliques_benchmark.cc"], deps = [ ":cliques", + "//ortools/base:benchmark_main", "//ortools/util:time_limit", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/functional:bind_front", @@ -178,7 +180,6 @@ cc_binary( "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/types:span", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -251,9 +252,9 @@ cc_binary( srcs = ["christofides_benchmark.cc"], deps = [ ":christofides", + "//ortools/base:benchmark_main", "@abseil-cpp//absl/log:check", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -280,10 +281,10 @@ cc_binary( srcs = ["eulerian_path_benchmark.cc"], deps = [ ":eulerian_path", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/log:check", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -316,11 +317,11 @@ cc_binary( srcs = ["minimum_spanning_tree_benchmark.cc"], deps = [ ":minimum_spanning_tree", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -423,11 +424,11 @@ cc_binary( deps = [ ":k_shortest_paths", ":shortest_paths", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -507,6 +508,7 @@ cc_binary( srcs = ["generic_max_flow_benchmark.cc"], deps = [ ":generic_max_flow", + "//ortools/base:benchmark_main", "//ortools/graph_base:flow_graph", "//ortools/graph_base:graph", "@abseil-cpp//absl/log:check", @@ -514,7 +516,6 @@ cc_binary( "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/types:span", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -570,12 +571,12 @@ cc_binary( srcs = ["min_cost_flow_benchmark.cc"], deps = [ ":min_cost_flow", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -635,10 +636,10 @@ cc_binary( srcs = ["assignment_benchmark.cc"], deps = [ ":assignment", + "//ortools/base:benchmark_main", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -678,11 +679,11 @@ cc_binary( srcs = ["linear_assignment_benchmark.cc"], deps = [ ":linear_assignment", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random:distributions", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -792,6 +793,7 @@ cc_binary( srcs = ["dag_constrained_shortest_path_benchmark.cc"], deps = [ ":dag_constrained_shortest_path", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log:check", @@ -799,7 +801,6 @@ cc_binary( "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/types:span", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -836,11 +837,11 @@ cc_binary( srcs = ["rooted_tree_benchmark.cc"], deps = [ ":rooted_tree", + "//ortools/base:benchmark_main", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -870,9 +871,9 @@ cc_binary( srcs = ["minimum_vertex_cover_benchmark.cc"], deps = [ ":minimum_vertex_cover", + "//ortools/base:benchmark_main", "@abseil-cpp//absl/algorithm:container", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) @@ -919,12 +920,12 @@ cc_binary( srcs = ["dag_shortest_path_benchmark.cc"], deps = [ ":dag_shortest_path", + "//ortools/base:benchmark_main", "//ortools/graph_base:graph", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@abseil-cpp//absl/types:span", "@google_benchmark//:benchmark", - "@google_benchmark//:benchmark_main", ], ) diff --git a/ortools/graph/bounded_dijkstra.h b/ortools/graph/bounded_dijkstra.h index 0824dde18e8..8df20e4739e 100644 --- a/ortools/graph/bounded_dijkstra.h +++ b/ortools/graph/bounded_dijkstra.h @@ -22,7 +22,6 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/base/attributes.h" #include "absl/log/check.h" #include "absl/types/span.h" @@ -30,6 +29,7 @@ #include "ortools/base/strong_vector.h" #include "ortools/base/top_n.h" #include "ortools/graph_base/graph.h" +#include "ortools/util/bitset.h" namespace operations_research { @@ -293,7 +293,7 @@ class BoundedDijkstraWrapper { ByNode distances_; ByNode parents_; ByNode arc_from_source_; - ByNode is_reached_; + Bitset64 is_reached_; std::vector reached_nodes_; // Priority queue of nodes, ordered by their distance to the source. @@ -332,6 +332,21 @@ class BoundedDijkstraWrapper { ByNode node_to_destination_index_; }; +// CTAD guides. +template +BoundedDijkstraWrapper(const GraphType* graph, + const DistanceContainer* arc_lengths) + -> BoundedDijkstraWrapper; + +template +BoundedDijkstraWrapper(const GraphType* graph, + ArcLengthFunctor arc_length_functor) + -> BoundedDijkstraWrapper< + GraphType, + std::invoke_result_t, + ArcLengthFunctor>; + // ----------------------------------------------------------------------------- // Implementation. // ----------------------------------------------------------------------------- @@ -494,16 +509,17 @@ BoundedDijkstraWrapper:: DistanceType distance_limit) { // Sparse clear is_reached_ from the last call. for (const NodeIndex node : reached_nodes_) { - is_reached_[node] = false; + is_reached_.Clear(node); } reached_nodes_.clear(); - DCHECK(!absl::c_linear_search(is_reached_, true)); - is_reached_.resize(graph_->num_nodes(), false); + is_reached_.Resize(graph_->num_nodes()); distances_.resize(graph_->num_nodes(), distance_limit); parents_.resize(graph_->num_nodes(), std::numeric_limits::min()); arc_from_source_.resize(graph_->num_nodes(), GraphType::kNilArc); + typename Bitset64::View is_reached = is_reached_.view(); + // Initialize sources. CHECK(queue_.empty()); node_to_source_index_.resize(graph_->num_nodes(), -1); @@ -515,9 +531,9 @@ BoundedDijkstraWrapper:: // Sources with an initial distance ≥ limit are *not* reached. if (distance >= distance_limit) continue; // Skip useless repetitions. - if (is_reached_[node] && distance >= distances_[node]) continue; - if (!is_reached_[node]) { - is_reached_[node] = true; + if (is_reached[node] && distance >= distances_[node]) continue; + if (!is_reached[node]) { + is_reached.Set(node); reached_nodes_.push_back(node); parents_[node] = node; // Set the parent to itself. } @@ -571,10 +587,10 @@ BoundedDijkstraWrapper:: const DistanceType candidate_distance = top.distance + arc_length; const NodeIndex head = graph_->Head(arc); - if (is_reached_[head]) { + if (is_reached[head]) { if (candidate_distance >= distances_[head]) continue; } else { - is_reached_[head] = true; + is_reached.Set(head); reached_nodes_.push_back(head); } distances_[head] = candidate_distance; diff --git a/ortools/graph/bounded_dijkstra_benchmark.cc b/ortools/graph/bounded_dijkstra_benchmark.cc index 64947285195..3849e62c04f 100644 --- a/ortools/graph/bounded_dijkstra_benchmark.cc +++ b/ortools/graph/bounded_dijkstra_benchmark.cc @@ -12,6 +12,7 @@ // limitations under the License. #include +#include #include #include #include @@ -27,35 +28,27 @@ namespace operations_research { namespace { -template -void BM_GridGraph(benchmark::State& state) { - typedef util::StaticGraph Graph; - const int64_t kWidth = 100; - const int64_t kHeight = 10000; - const int kSourceNode = static_cast(kWidth * kHeight / 2); - std::unique_ptr graph = - util::Create2DGridGraph(/*width=*/kWidth, /*height=*/kHeight); - BoundedDijkstraWrapper::ByArc arc_lengths( - graph->num_arcs(), 0); - const int64_t min_length = arc_lengths_are_discrete ? 0 : 1; - const int64_t max_length = arc_lengths_are_discrete ? 2 : 1000000000000000L; - std::mt19937 random(12345); - for (int64_t& length : arc_lengths) { - length = absl::Uniform(random, min_length, max_length + 1); - } - BoundedDijkstraWrapper dijkstra(graph.get(), &arc_lengths); - const int64_t kSearchRadius = kWidth * (min_length + max_length) / 2; +constexpr int64_t kBenchmarkWidth = 100; +constexpr int64_t kBenchmarkHeight = 10000; +using Graph = util::StaticGraph; + +template +void RunBenchmark(benchmark::State& state, T& dijkstra, + const int64_t search_radius) { + const int kSourceNode = + static_cast(kBenchmarkWidth * kBenchmarkHeight / 2); // NOTE(user): The expected number of nodes visited is in ϴ(kWidth²), // since the search radius is ϴ(kWidth). The exact constant is hard to // derive mathematically as a function of the radius formula, so I measured // it empirically and it was close to 0.5, which seemed about right. - const int64_t kExpectedApproximateSearchSize = kWidth * kWidth / 2; + const int64_t kExpectedApproximateSearchSize = + kBenchmarkWidth * kBenchmarkWidth / 2; int64_t total_nodes_visited = 0; for (auto _ : state) { const int num_nodes_visited = dijkstra .RunBoundedDijkstra(/*source_node=*/kSourceNode, - /*distance_limit=*/kSearchRadius) + /*distance_limit=*/search_radius) .size(); // We verify that the Dijkstra search size is in the expected range, to // make sure that we're measuring what we want in this benchmark. @@ -66,8 +59,45 @@ void BM_GridGraph(benchmark::State& state) { state.SetItemsProcessed(total_nodes_visited); } +template +void BM_GridGraph(benchmark::State& state) { + std::unique_ptr graph = + util::Create2DGridGraph(kBenchmarkWidth, kBenchmarkHeight); + BoundedDijkstraWrapper::ByArc arc_lengths( + graph->num_arcs(), 0); + const int64_t min_length = arc_lengths_are_discrete ? 0 : 1; + const int64_t max_length = arc_lengths_are_discrete ? 2 : 1000000000000000L; + std::mt19937 random(12345); + for (int64_t& length : arc_lengths) { + length = absl::Uniform(random, min_length, max_length + 1); + } + const int64_t kSearchRadius = kBenchmarkWidth * (min_length + max_length) / 2; + BoundedDijkstraWrapper dijkstra(graph.get(), &arc_lengths); + RunBenchmark(state, dijkstra, kSearchRadius); +} + BENCHMARK(BM_GridGraph); BENCHMARK(BM_GridGraph); +void BM_GridGraph_ArcLengthFunctor(benchmark::State& state) { + std::unique_ptr graph = + util::Create2DGridGraph(kBenchmarkWidth, kBenchmarkHeight); + BoundedDijkstraWrapper dijkstra( + graph.get(), [](int node_id) { return 1 + (node_id % 3); }); + RunBenchmark(state, dijkstra, 2 * kBenchmarkWidth); +} +BENCHMARK(BM_GridGraph_ArcLengthFunctor); + +void BM_GridGraph_ArcLengthStdFunction(benchmark::State& state) { + std::unique_ptr graph = + util::Create2DGridGraph(kBenchmarkWidth, kBenchmarkHeight); + std::function get_arc_length = [](int node_id) { + return 1 + (node_id % 3); + }; + BoundedDijkstraWrapper dijkstra(graph.get(), get_arc_length); + RunBenchmark(state, dijkstra, 2 * kBenchmarkWidth); +} +BENCHMARK(BM_GridGraph_ArcLengthStdFunction); + } // namespace } // namespace operations_research diff --git a/ortools/graph/bounded_dijkstra_test.cc b/ortools/graph/bounded_dijkstra_test.cc index 06851ed2bed..2f7a6151f8f 100644 --- a/ortools/graph/bounded_dijkstra_test.cc +++ b/ortools/graph/bounded_dijkstra_test.cc @@ -63,8 +63,7 @@ TEST(BoundedDijkstraWrapperDeathTest, Accessors) { TEST(BoundedDijkstraWrapperDeathTest, WithArcLengthFunctor) { TestGraph graph; graph.AddArc(NodeIndex(1), NodeIndex(3)); - BoundedDijkstraWrapper> - dijkstra(&graph, [](ArcIndex) { return 2.34; }); + BoundedDijkstraWrapper dijkstra(&graph, [](ArcIndex) { return 2.34f; }); ASSERT_FLOAT_EQ(dijkstra.GetArcLength(ArcIndex(0)), 2.34f); } diff --git a/ortools/sat/2d_distances_propagator.cc b/ortools/sat/2d_distances_propagator.cc index 32411f4949a..c2a95aa1ced 100644 --- a/ortools/sat/2d_distances_propagator.cc +++ b/ortools/sat/2d_distances_propagator.cc @@ -21,6 +21,7 @@ #include #include "absl/algorithm/container.h" +#include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" @@ -38,6 +39,7 @@ #include "ortools/sat/sat_base.h" #include "ortools/sat/scheduling_helpers.h" #include "ortools/sat/synchronization.h" +#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -50,7 +52,8 @@ Precedences2DPropagator::Precedences2DPropagator( shared_stats_(model->GetOrCreate()), lin2_indices_(model->GetOrCreate()), trail_(model->GetOrCreate()), - integer_trail_(model->GetOrCreate()) { + integer_trail_(model->GetOrCreate()), + time_limit_(model->GetOrCreate()) { model->GetOrCreate()->SetPushAffineUbForBinaryRelation(); } @@ -178,6 +181,10 @@ bool Precedences2DPropagator::Propagate() { num_calls_++; + absl::Cleanup increase_dtime = [this, size = pair_data_.size()]() { + time_limit_->AdvanceDeterministicTime(size * 5e-9); + }; + for (const PairData& pair_data : pair_data_) { if (!absl::c_all_of(pair_data.pair_presence_literals, [this](const Literal& literal) { diff --git a/ortools/sat/2d_distances_propagator.h b/ortools/sat/2d_distances_propagator.h index f170f13144d..b1c921bd678 100644 --- a/ortools/sat/2d_distances_propagator.h +++ b/ortools/sat/2d_distances_propagator.h @@ -28,6 +28,7 @@ #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/synchronization.h" +#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -85,6 +86,7 @@ class Precedences2DPropagator : public PropagatorInterface { Linear2Indices* lin2_indices_; Trail* trail_; IntegerTrail* integer_trail_; + TimeLimit* time_limit_; int last_helper_inprocessing_count_ = -1; int num_known_linear2_ = 0; diff --git a/ortools/sat/2d_orthogonal_packing.h b/ortools/sat/2d_orthogonal_packing.h index 3f22cc01364..0ef2f6d1a1d 100644 --- a/ortools/sat/2d_orthogonal_packing.h +++ b/ortools/sat/2d_orthogonal_packing.h @@ -23,6 +23,7 @@ #include "absl/random/bit_gen_ref.h" #include "absl/types/span.h" #include "ortools/base/constant_divisor.h" +#include "ortools/base/types.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/synchronization.h" #include "ortools/util/bitset.h" @@ -35,7 +36,7 @@ struct OrthogonalPackingOptions { bool use_dff_f0 = true; bool use_dff_f2 = true; int brute_force_threshold = 6; - int dff2_max_number_of_parameters_to_check = std::numeric_limits::max(); + int dff2_max_number_of_parameters_to_check = kint32max; }; class OrthogonalPackingResult { @@ -322,7 +323,7 @@ class RoundingDualFeasibleFunctionPowerOfTwo { DCHECK_GE(log2_k_, 0); DCHECK_LT(log2_k_, 63); DCHECK_LE(2 * (1 << log2_k), max_x_); - DCHECK_LE(max_x_, std::numeric_limits::max() / 2); + DCHECK_LE(max_x_, kint64max / 2); } IntegerValue operator()(IntegerValue x) const { diff --git a/ortools/sat/2d_orthogonal_packing_test.cc b/ortools/sat/2d_orthogonal_packing_test.cc index a9fdab2725b..54d4fb775bf 100644 --- a/ortools/sat/2d_orthogonal_packing_test.cc +++ b/ortools/sat/2d_orthogonal_packing_test.cc @@ -30,6 +30,7 @@ #include "gtest/gtest.h" #include "ortools/algorithms/binary_search.h" #include "ortools/base/gmock.h" +#include "ortools/base/types.h" #include "ortools/sat/2d_orthogonal_packing_testing.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer_base.h" @@ -125,8 +126,8 @@ TEST(DualFeasibilityFunctionTest, Dff2IsMaximal) { TEST(DualFeasibilityFunctionTest, DffPowerOfTwo) { absl::BitGen random; for (int k = 0; k < 61; k++) { - TestMaximalDff( - random, 100, std::numeric_limits::max() / 2, k); + TestMaximalDff(random, 100, + kint64max / 2, k); } } diff --git a/ortools/sat/2d_rectangle_presolve.cc b/ortools/sat/2d_rectangle_presolve.cc index 068ab219bb5..a32de4af720 100644 --- a/ortools/sat/2d_rectangle_presolve.cc +++ b/ortools/sat/2d_rectangle_presolve.cc @@ -15,11 +15,13 @@ #include #include +#include +#include #include -#include #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -32,7 +34,6 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/log_severity.h" -#include "ortools/base/stl_util.h" #include "ortools/graph/minimum_vertex_cover.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/sat/diffn_util.h" @@ -47,45 +48,14 @@ std::vector FindSpacesThatCannotBeOccupied( absl::Span non_fixed_boxes, const Rectangle& bounding_box, IntegerValue min_x_size, IntegerValue min_y_size) { - std::vector optional_boxes = {fixed_boxes.begin(), - fixed_boxes.end()}; - - if (bounding_box.x_min > std::numeric_limits::min() && - bounding_box.y_min > std::numeric_limits::min() && - bounding_box.x_max < std::numeric_limits::max() && - bounding_box.y_max < std::numeric_limits::max()) { - // Add fake rectangles to build a frame around the bounding box. This allows - // to find more areas that must be empty. The frame is as follows: - // +************ - // +...........+ - // +...........+ - // +...........+ - // ************+ - optional_boxes.push_back({.x_min = bounding_box.x_min - 1, - .x_max = bounding_box.x_max, - .y_min = bounding_box.y_min - 1, - .y_max = bounding_box.y_min}); - optional_boxes.push_back({.x_min = bounding_box.x_max, - .x_max = bounding_box.x_max + 1, - .y_min = bounding_box.y_min - 1, - .y_max = bounding_box.y_max}); - optional_boxes.push_back({.x_min = bounding_box.x_min, - .x_max = bounding_box.x_max + 1, - .y_min = bounding_box.y_max, - .y_max = bounding_box.y_max + 1}); - optional_boxes.push_back({.x_min = bounding_box.x_min - 1, - .x_max = bounding_box.x_min, - .y_min = bounding_box.y_min, - .y_max = bounding_box.y_max + 1}); - } - + std::vector optional_boxes(fixed_boxes.begin(), fixed_boxes.end()); // All items we added to `optional_boxes` at this point are only to be used by // the "gap between items" logic below. They are not actual optional boxes and // should be removed right after the logic is applied. const int num_optional_boxes_to_remove = optional_boxes.size(); - // Add a rectangle to `optional_boxes` but respecting that rectangles must - // remain disjoint. + // Add a rectangle to `optional_boxes` while respecting that the set of + // rectangles in `optional_boxes` must remain mutually disjoint. const auto add_box = [&optional_boxes](Rectangle new_box) { std::vector to_add = {std::move(new_box)}; for (int i = 0; i < to_add.size(); ++i) { @@ -110,7 +80,8 @@ std::vector FindSpacesThatCannotBeOccupied( // Now check if there is any space that cannot be occupied by any non-fixed // item. // TODO(user): remove the limit of 1000 and reimplement FindEmptySpaces() - // using a sweep line algorithm. + // using a sweep line algorithm. Note that supporting non-disjoint input is + // necessary here, so we cannot use FindEmptySpacesHorizontally. if (non_fixed_boxes.size() < 1000) { std::vector bounding_boxes; bounding_boxes.reserve(non_fixed_boxes.size()); @@ -124,40 +95,28 @@ std::vector FindSpacesThatCannotBeOccupied( } } - // Now look for gaps between objects that are too small to place anything. - for (int i = 1; i < optional_boxes.size(); ++i) { - const Rectangle cur_box = optional_boxes[i]; - for (int j = 0; j < i; ++j) { - const Rectangle& other_box = optional_boxes[j]; - const IntegerValue lower_top = std::min(cur_box.y_max, other_box.y_max); - const IntegerValue higher_bottom = - std::max(other_box.y_min, cur_box.y_min); - const IntegerValue rightmost_left_edge = - std::max(other_box.x_min, cur_box.x_min); - const IntegerValue leftmost_right_edge = - std::min(other_box.x_max, cur_box.x_max); - if (rightmost_left_edge < leftmost_right_edge) { - if (lower_top < higher_bottom && - higher_bottom - lower_top < min_y_size) { - add_box({.x_min = rightmost_left_edge, - .x_max = leftmost_right_edge, - .y_min = lower_top, - .y_max = higher_bottom}); - } - } - if (higher_bottom < lower_top) { - if (leftmost_right_edge < rightmost_left_edge && - rightmost_left_edge - leftmost_right_edge < min_x_size) { - add_box({.x_min = leftmost_right_edge, - .x_max = rightmost_left_edge, - .y_min = higher_bottom, - .y_max = lower_top}); - } - } + // Look for horizontal gaps that are too narrow to place any non-fixed item. + // We can write directly into `optional_boxes` instead of calling `add_box` + // because FindEmptySpacesHorizontally is guaranteed to return disjoint + // rectangles that are also disjoint from the input `dead_spaces`. + for (const auto& space : + FindEmptySpacesHorizontally(bounding_box, optional_boxes)) { + if (space.rect.x_max - space.rect.x_min < min_x_size) { + optional_boxes.push_back(space.rect); } } + + // Look for vertical gaps that are too short to place any non-fixed item. + for (const auto& space : + FindEmptySpacesVertically(bounding_box, optional_boxes)) { + if (space.rect.y_max - space.rect.y_min < min_y_size) { + optional_boxes.push_back(space.rect); + } + } + optional_boxes.erase(optional_boxes.begin(), optional_boxes.begin() + num_optional_boxes_to_remove); + return optional_boxes; } @@ -348,14 +307,16 @@ bool ReduceNumberofBoxesGreedy(std::vector* mandatory_rectangles, std::vector* optional_rectangles) { // The current implementation just greedly merge rectangles that shares an // edge. - std::vector> rectangle_storage; + std::deque rectangle_storage; enum class OptionalEnum { OPTIONAL, MANDATORY }; // bool for is_optional std::vector> rectangles_ptr; - absl::flat_hash_map top_edges_to_rectangle; - absl::flat_hash_map bottom_edges_to_rectangle; - absl::flat_hash_map left_edges_to_rectangle; - absl::flat_hash_map right_edges_to_rectangle; + const int total_rectangles = + mandatory_rectangles->size() + optional_rectangles->size(); + absl::flat_hash_map top_edges_to_rectangle(total_rectangles); + absl::flat_hash_map bottom_edges_to_rectangle(total_rectangles); + absl::flat_hash_map left_edges_to_rectangle(total_rectangles); + absl::flat_hash_map right_edges_to_rectangle(total_rectangles); bool changed_optional = false; bool changed_mandatory = false; @@ -438,8 +399,8 @@ bool ReduceNumberofBoxesGreedy(std::vector* mandatory_rectangles, changed_optional || (rectangles_ptr[i].second == OptionalEnum::OPTIONAL || rectangles_ptr[index].second == OptionalEnum::OPTIONAL); - rectangle_storage.push_back(std::make_unique(rectangle)); - Rectangle& new_rectangle = *rectangle_storage.back(); + rectangle_storage.push_back(rectangle); + Rectangle& new_rectangle = rectangle_storage.back(); new_rectangle.GrowToInclude(*rectangles_ptr[index].first); remove_rectangle(i); remove_rectangle(index); @@ -472,7 +433,8 @@ Neighbours BuildNeighboursGraph(absl::Span rectangles) { // To build a graph of neighbours, we build a sorted vector for each one of // the edges (top, bottom, etc) of the rectangles. Then we merge the bottom // and top vectors and iterate on it. Due to the sorting order, segments where - // the bottom of a rectangle touches the top of another one must consecutive. + // the bottom of a rectangle touches the top of another one must be + // consecutive in the vector. std::vector> edges_to_rectangle[4]; std::vector> neighbours; neighbours.reserve(2 * rectangles.size()); @@ -504,59 +466,64 @@ Neighbours BuildNeighboursGraph(absl::Span rectangles) { absl::c_sort(edges_to_rectangle[edge_position], cmp); } - constexpr struct EdgeData { - EdgePosition edge; - EdgePosition opposite_edge; - bool (*cmp)(const Edge&, const Edge&); - } edge_data[4] = {{.edge = EdgePosition::BOTTOM, - .opposite_edge = EdgePosition::TOP, - .cmp = &Edge::CompareYThenX}, - {.edge = EdgePosition::TOP, - .opposite_edge = EdgePosition::BOTTOM, - .cmp = &Edge::CompareYThenX}, - {.edge = EdgePosition::LEFT, - .opposite_edge = EdgePosition::RIGHT, - .cmp = &Edge::CompareXThenY}, - {.edge = EdgePosition::RIGHT, - .opposite_edge = EdgePosition::LEFT, - .cmp = &Edge::CompareXThenY}}; - - for (int edge_int = 0; edge_int < 4; ++edge_int) { - const EdgePosition edge_position = edge_data[edge_int].edge; - const EdgePosition opposite_edge_position = - edge_data[edge_int].opposite_edge; + for (const bool bottom_vs_top : {true, false}) { + // bottom_vs_top: compare bottom of a box with top of another one. + // !bottom_vs_top: compare left of a box with right of another one. + const EdgePosition edge_position = + bottom_vs_top ? EdgePosition::BOTTOM : EdgePosition::LEFT; + const EdgePosition opposite = + bottom_vs_top ? EdgePosition::TOP : EdgePosition::RIGHT; auto it = edges_to_rectangle[edge_position].begin(); - for (const auto& [edge, index] : - edges_to_rectangle[opposite_edge_position]) { - while (it != edges_to_rectangle[edge_position].end() && - edge_data[edge_int].cmp(it->first, edge)) { - ++it; - } - if (it == edges_to_rectangle[edge_position].end()) { - break; + const auto end = edges_to_rectangle[edge_position].end(); + + for (const auto& [edge, index] : edges_to_rectangle[opposite]) { + // Advance the main iterator only past intervals that finish completely + // before our current 'edge' starts. + while (it != end) { + if (bottom_vs_top) { // BOTTOM vs TOP. + if (it->first.y_start < edge.y_start || + (it->first.y_start == edge.y_start && + it->first.x_start + it->first.size <= edge.x_start)) { + ++it; + continue; + } + } else { // LEFT vs RIGHT. + if (it->first.x_start < edge.x_start || + (it->first.x_start == edge.x_start && + it->first.y_start + it->first.size <= edge.y_start)) { + ++it; + continue; + } + } + break; // Stop advancing main iterator. } - if (edge_position == EdgePosition::BOTTOM || - edge_position == EdgePosition::TOP) { - while (it != edges_to_rectangle[edge_position].end() && - it->first.y_start == edge.y_start && - it->first.x_start < edge.x_start + edge.size) { - neighbours.push_back({index, opposite_edge_position, it->second}); - neighbours.push_back({it->second, edge_position, index}); - ++it; + + // Use a temporary scan iterator to find all overlaps for this specific + // edge. + auto scan_it = it; + if (bottom_vs_top) { // Vertical edges. + while (scan_it != end && scan_it->first.y_start == edge.y_start && + scan_it->first.x_start < edge.x_start + edge.size) { + neighbours.push_back({index, opposite, scan_it->second}); + neighbours.push_back({scan_it->second, edge_position, index}); + ++scan_it; } - } else { - while (it != edges_to_rectangle[edge_position].end() && - it->first.x_start == edge.x_start && - it->first.y_start < edge.y_start + edge.size) { - neighbours.push_back({index, opposite_edge_position, it->second}); - neighbours.push_back({it->second, edge_position, index}); - ++it; + } else { // Horizontal edges. + while (scan_it != end && scan_it->first.x_start == edge.x_start && + scan_it->first.y_start < edge.y_start + edge.size) { + neighbours.push_back({index, opposite, scan_it->second}); + neighbours.push_back({scan_it->second, edge_position, index}); + ++scan_it; } } } } - gtl::STLSortAndRemoveDuplicates(&neighbours); + using EdgeTuple = std::tuple; + // We don't expect any duplicate neighbours. + DCHECK_EQ(neighbours.size(), + absl::flat_hash_set(neighbours.begin(), neighbours.end()) + .size()); return Neighbours(rectangles, neighbours); } @@ -590,34 +557,6 @@ std::vector> SplitInConnectedComponents( } namespace { -IntegerValue GetClockwiseStart(EdgePosition edge, const Rectangle& rectangle) { - switch (edge) { - case EdgePosition::LEFT: - return rectangle.y_min; - case EdgePosition::RIGHT: - return rectangle.y_max; - case EdgePosition::BOTTOM: - return rectangle.x_max; - case EdgePosition::TOP: - return rectangle.x_min; - } - LOG(FATAL) << "Invalid edge position: " << static_cast(edge); -} - -IntegerValue GetClockwiseEnd(EdgePosition edge, const Rectangle& rectangle) { - switch (edge) { - case EdgePosition::LEFT: - return rectangle.y_max; - case EdgePosition::RIGHT: - return rectangle.y_min; - case EdgePosition::BOTTOM: - return rectangle.x_min; - case EdgePosition::TOP: - return rectangle.x_max; - } - LOG(FATAL) << "Invalid edge position: " << static_cast(edge); -} - // Given a list of rectangles and their neighbours graph, find the list of // vertical and horizontal segments that touches a single rectangle edge. Or, // view in another way, the pieces of an edge that is touching the empty space. @@ -636,202 +575,203 @@ void GetAllSegmentsTouchingVoid( std::vector>& vertical_edges_on_boundary, std::vector>& horizontal_edges_on_boundary) { for (int i = 0; i < rectangles.size(); ++i) { - const Rectangle& rectangle = rectangles[i]; - for (int edge_int = 0; edge_int < 4; ++edge_int) { - const EdgePosition edge = static_cast(edge_int); - const auto box_neighbors = neighbours.GetSortedNeighbors(i, edge); - if (box_neighbors.empty()) { - if (edge == EdgePosition::LEFT || edge == EdgePosition::RIGHT) { - vertical_edges_on_boundary.push_back( - {Edge::GetEdge(rectangle, edge), i}); - } else { - horizontal_edges_on_boundary.push_back( - {Edge::GetEdge(rectangle, edge), i}); - } - continue; - } - IntegerValue previous_pos = GetClockwiseStart(edge, rectangle); - for (int n = 0; n <= box_neighbors.size(); ++n) { - IntegerValue neighbor_start; - const Rectangle* neighbor; - if (n == box_neighbors.size()) { - // On the last iteration we consider instead of the next neighbor the - // end of the current box. - neighbor_start = GetClockwiseEnd(edge, rectangle); - } else { - const int neighbor_idx = box_neighbors[n]; - neighbor = &rectangles[neighbor_idx]; - neighbor_start = GetClockwiseStart(edge, *neighbor); + const Rectangle& rect = rectangles[i]; + + auto process_edge = [&](EdgePosition pos, IntegerValue min_bound, + IntegerValue max_bound, IntegerValue fixed_pos) { + absl::Span neighbors = neighbours.GetSortedNeighbors(i, pos); + + // GetSortedNeighbors() returns the neighbors in clockwise order. + const bool is_descending = + (pos == EdgePosition::BOTTOM || pos == EdgePosition::RIGHT); + const bool is_vertical = + (pos == EdgePosition::LEFT || pos == EdgePosition::RIGHT); + IntegerValue current_min = min_bound; + + auto add_result = [&](IntegerValue start, IntegerValue end) { + const Edge e = is_vertical ? Edge{.x_start = fixed_pos, + .y_start = start, + .size = end - start} + : Edge{.x_start = start, + .y_start = fixed_pos, + .size = end - start}; + std::vector>& edges = + is_vertical ? vertical_edges_on_boundary + : horizontal_edges_on_boundary; + edges.push_back({e, i}); + }; + + auto process_neighbor = [&](int neighbor_idx) { + const Rectangle& neighbor = rectangles[neighbor_idx]; + const IntegerValue n_min = + is_vertical ? neighbor.y_min : neighbor.x_min; + const IntegerValue n_max = + is_vertical ? neighbor.y_max : neighbor.x_max; + + if (n_min > current_min) { + add_result(current_min, n_min); } - switch (edge) { - case EdgePosition::LEFT: - if (neighbor_start > previous_pos) { - vertical_edges_on_boundary.push_back( - {Edge{.x_start = rectangle.x_min, - .y_start = previous_pos, - .size = neighbor_start - previous_pos}, - i}); - } - break; - case EdgePosition::RIGHT: - if (neighbor_start < previous_pos) { - vertical_edges_on_boundary.push_back( - {Edge{.x_start = rectangle.x_max, - .y_start = neighbor_start, - .size = previous_pos - neighbor_start}, - i}); - } - break; - case EdgePosition::BOTTOM: - if (neighbor_start < previous_pos) { - horizontal_edges_on_boundary.push_back( - {Edge{.x_start = neighbor_start, - .y_start = rectangle.y_min, - .size = previous_pos - neighbor_start}, - i}); - } - break; - case EdgePosition::TOP: - if (neighbor_start > previous_pos) { - horizontal_edges_on_boundary.push_back( - {Edge{.x_start = previous_pos, - .y_start = rectangle.y_max, - .size = neighbor_start - previous_pos}, - i}); - } - break; + current_min = std::max(current_min, n_max); + }; + + if (is_descending) { + for (auto it = neighbors.rbegin(); it != neighbors.rend(); ++it) { + process_neighbor(*it); } - if (n != box_neighbors.size()) { - previous_pos = GetClockwiseEnd(edge, *neighbor); + } else { + for (const int neighbor_idx : neighbors) { + process_neighbor(neighbor_idx); } } - } + + if (current_min < max_bound) { + add_result(current_min, max_bound); + } + }; + + process_edge(EdgePosition::LEFT, rect.y_min, rect.y_max, rect.x_min); + process_edge(EdgePosition::RIGHT, rect.y_min, rect.y_max, rect.x_max); + process_edge(EdgePosition::BOTTOM, rect.x_min, rect.x_max, rect.y_min); + process_edge(EdgePosition::TOP, rect.x_min, rect.x_max, rect.y_max); } } +using Point = std::pair; + +// A directed rectangle edge. +struct BoundarySegment { + Point start; + Point destination; + int box_index; +}; + +// Unified transparent comparator for sorting and binary searching. +struct CompareBoundarySegment { + using is_transparent = void; + + bool operator()(const BoundarySegment& a, const BoundarySegment& b) const { + return std::tie(a.start, a.destination) < std::tie(b.start, b.destination); + } + bool operator()(const BoundarySegment& e, const Point& p) const { + return e.start < p; + } + bool operator()(const Point& p, const BoundarySegment& e) const { + return p < e.start; + } +}; + // Trace a boundary (interior or exterior) that contains the edge described by -// starting_edge_position and starting_step_point. This method removes the edges -// that were added to the boundary from `segments_to_follow`. -ShapePath TraceBoundary( - const EdgePosition& starting_edge_position, - std::pair starting_step_point, - std::array, - std::pair>, - 4>& segments_to_follow) { - // The boundary is composed of edges on the `segments_to_follow` map. So all - // we need is to find and glue them together on the right order. +// `initial_edge_idx`. This method sets to true in `edge_used` the indexes of +// the edges that were added to the boundary and skips edges that have been used +// before. +ShapePath TraceBoundary(int initial_edge_idx, + const std::vector& sorted_edges, + std::vector& edge_used) { ShapePath path; + int current_idx = initial_edge_idx; + const Point start_point = sorted_edges[current_idx].start; + Point prev = start_point; - auto extracted = - segments_to_follow[starting_edge_position].extract(starting_step_point); - CHECK(!extracted.empty()); - const int first_index = extracted.mapped().second; - - std::pair cur = starting_step_point; - int cur_index = first_index; - // Now we navigate from one edge to the next. To avoid going back, we remove - // used edges from the hash map. while (true) { - path.step_points.push_back(cur); + const BoundarySegment& current_seg = sorted_edges[current_idx]; + path.step_points.push_back(prev); + path.touching_box_index.push_back(current_seg.box_index); - bool can_go[4] = {false, false, false, false}; - EdgePosition direction_to_take = EdgePosition::LEFT; - for (int edge_int = 0; edge_int < 4; ++edge_int) { - const EdgePosition edge = static_cast(edge_int); - if (segments_to_follow[edge].contains(cur)) { - can_go[edge] = true; - direction_to_take = edge; - } - } + const Point cur = current_seg.destination; - if (can_go == absl::Span{false, false, false, false}) { - // Cannot move anywhere, we closed the loop. + edge_used[current_idx] = true; + + if (cur == start_point) { break; } - // Handle one pathological case. - if (can_go[EdgePosition::LEFT] && can_go[EdgePosition::RIGHT]) { - // Corner case (literally): - // ******** - // ******** - // ******** - // ******** - // ^ +++++++++ - // | +++++++++ - // | +++++++++ - // +++++++++ - // - // In this case we keep following the same box. - auto it_x = segments_to_follow[EdgePosition::LEFT].find(cur); - if (cur_index == it_x->second.second) { - direction_to_take = EdgePosition::LEFT; - } else { - direction_to_take = EdgePosition::RIGHT; - } - } else if (can_go[EdgePosition::TOP] && can_go[EdgePosition::BOTTOM]) { - auto it_y = segments_to_follow[EdgePosition::TOP].find(cur); - if (cur_index == it_y->second.second) { - direction_to_take = EdgePosition::TOP; - } else { - direction_to_take = EdgePosition::BOTTOM; - } + const auto lower_it = + std::lower_bound(sorted_edges.begin(), sorted_edges.end(), cur, + CompareBoundarySegment()); + const int start_idx = + static_cast(std::distance(sorted_edges.begin(), lower_it)); + + // We should have typically one outgoing edge, and at most two in case of a + // pinch-point like the one in the example below: + // ******** + // ******** + // ******** + // ******** + // ^ +++++++++ + // | +++++++++ + // | +++++++++ + // +++++++++ + // + int outgoing_edges = 0; + while (start_idx + outgoing_edges < sorted_edges.size() && + sorted_edges[start_idx + outgoing_edges].start == cur) { + ++outgoing_edges; } - auto extracted = segments_to_follow[direction_to_take].extract(cur); - cur_index = extracted.mapped().second; - switch (direction_to_take) { - case EdgePosition::LEFT: - cur.first -= extracted.mapped().first; - segments_to_follow[EdgePosition::RIGHT].erase( - cur); // Forbid going back - break; - case EdgePosition::RIGHT: - cur.first += extracted.mapped().first; - segments_to_follow[EdgePosition::LEFT].erase(cur); // Forbid going back - break; - case EdgePosition::TOP: - cur.second += extracted.mapped().first; - segments_to_follow[EdgePosition::BOTTOM].erase( - cur); // Forbid going back - break; - case EdgePosition::BOTTOM: - cur.second -= extracted.mapped().first; - segments_to_follow[EdgePosition::TOP].erase(cur); // Forbid going back - break; - } - path.touching_box_index.push_back(cur_index); + DCHECK_LE(outgoing_edges, 2); + + // If we are in a pinch-point and have two options, prefer the one that + // keeps following the same box. + const int next_idx = + (outgoing_edges == 2 && + sorted_edges[start_idx].box_index != current_seg.box_index + ? start_idx + 1 + : start_idx); + + DCHECK(!edge_used[next_idx]); + + prev = cur; + current_idx = next_idx; } - path.touching_box_index.push_back(cur_index); return path; } + } // namespace std::vector BoxesToShapes(absl::Span rectangles, const Neighbours& neighbours) { std::vector> vertical_edges_on_boundary; std::vector> horizontal_edges_on_boundary; + GetAllSegmentsTouchingVoid(rectangles, neighbours, vertical_edges_on_boundary, horizontal_edges_on_boundary); - std::array, - std::pair>, - 4> - segments_to_follow; + std::vector sorted_edges; + sorted_edges.reserve(vertical_edges_on_boundary.size() + + horizontal_edges_on_boundary.size()); + // We pick the right orientation for each edge so that we only produce + // paths in the clockwise direction. for (const auto& [edge, box_index] : vertical_edges_on_boundary) { - segments_to_follow[EdgePosition::TOP][{edge.x_start, edge.y_start}] = { - edge.size, box_index}; - segments_to_follow[EdgePosition::BOTTOM][{ - edge.x_start, edge.y_start + edge.size}] = {edge.size, box_index}; + const Rectangle& rect = rectangles[box_index]; + if (edge.x_start == rect.x_max) { + sorted_edges.push_back({{edge.x_start, edge.y_start + edge.size}, + {edge.x_start, edge.y_start}, + box_index}); + } else { + sorted_edges.push_back({{edge.x_start, edge.y_start}, + {edge.x_start, edge.y_start + edge.size}, + box_index}); + } } + for (const auto& [edge, box_index] : horizontal_edges_on_boundary) { - segments_to_follow[EdgePosition::RIGHT][{edge.x_start, edge.y_start}] = { - edge.size, box_index}; - segments_to_follow[EdgePosition::LEFT][{ - edge.x_start + edge.size, edge.y_start}] = {edge.size, box_index}; + const Rectangle& rect = rectangles[box_index]; + if (edge.y_start == rect.y_max) { + sorted_edges.push_back({{edge.x_start, edge.y_start}, + {edge.x_start + edge.size, edge.y_start}, + box_index}); + } else { + sorted_edges.push_back({{edge.x_start + edge.size, edge.y_start}, + {edge.x_start, edge.y_start}, + box_index}); + } } + absl::c_sort(sorted_edges, CompareBoundarySegment()); + std::vector edge_used(sorted_edges.size(), false); + const auto components = SplitInConnectedComponents(neighbours); std::vector result(components.size()); std::vector box_to_component(rectangles.size()); @@ -840,96 +780,188 @@ std::vector BoxesToShapes(absl::Span rectangles, box_to_component[box_index] = i; } } - while (!segments_to_follow[EdgePosition::LEFT].empty()) { - // Get edge most to the bottom left - const int box_index = - segments_to_follow[EdgePosition::RIGHT].begin()->second.second; - const std::pair starting_step_point = - segments_to_follow[EdgePosition::RIGHT].begin()->first; + + for (int i = 0; i < sorted_edges.size(); ++i) { + if (edge_used[i]) continue; + + const int box_index = sorted_edges[i].box_index; const int component_index = box_to_component[box_index]; - // The left-most vertical edge of the connected component must be of its - // exterior boundary. So we must always see the exterior boundary before - // seeing any holes. + // The lexicographical sort (min X, min Y) acts as a free topological test. + // For any given component, the very first unvisited edge we encounter is + // mathematically guaranteed to be its absolute bottom-left extremity. + // Because an extreme outer point cannot exist inside a hole, the first + // path we extract for a component is always its exterior boundary. + // Any subsequent paths found for this component must therefore be holes. const bool is_hole = !result[component_index].boundary.step_points.empty(); ShapePath& path = is_hole ? result[component_index].holes.emplace_back() : result[component_index].boundary; - path = TraceBoundary(EdgePosition::RIGHT, starting_step_point, - segments_to_follow); - if (is_hole) { - // Follow the usual convention that holes are in the inverse orientation - // of the external boundary. - absl::c_reverse(path.step_points); - absl::c_reverse(path.touching_box_index); - } + + // Note that by keeping the solid on the right we are getting the exterior + // boundary with clockwise orientation and the interior holes with + // counter-clockwise orientation naturally. + path = TraceBoundary(i, sorted_edges, edge_used); } + return result; } namespace { -struct PolygonCut { - std::pair start; - std::pair end; - int start_index; - int end_index; - - struct CmpByStartY { - bool operator()(const PolygonCut& a, const PolygonCut& b) const { - return std::tie(a.start.second, a.start.first) < - std::tie(b.start.second, b.start.first); - } - }; - struct CmpByEndY { - bool operator()(const PolygonCut& a, const PolygonCut& b) const { - return std::tie(a.end.second, a.end.first) < - std::tie(b.end.second, b.end.first); +EdgePosition GetSegmentDirection( + const std::pair& curr_segment, + const std::pair& next_segment) { + if (curr_segment.first == next_segment.first) { + return next_segment.second > curr_segment.second ? EdgePosition::TOP + : EdgePosition::BOTTOM; + } else { + return next_segment.first > curr_segment.first ? EdgePosition::RIGHT + : EdgePosition::LEFT; + } +} + +struct WallEvent { + IntegerValue min_y; + IntegerValue max_y; + int wall_start_idx; +}; + +struct RayStartEvent { + IntegerValue y; + int ray_start_idx; +}; + +struct SweepEvent { + IntegerValue x; + // Note the order: WallEvent is index 0, RayStartEvent is index 1. + std::variant payload; + + bool operator<(const SweepEvent& other) const { + if (x != other.x) return x < other.x; + // Process Walls (index 0) before RayStartEvents (index 1) at the same X. + return payload.index() < other.payload.index(); + } +}; + +std::vector ExtractRightCuts(FlatShape& shape) { + std::vector cuts; + const int initial_points = shape.points.size(); + std::vector events; + events.reserve(2 * initial_points); + + // Gather all events. + for (int i = 0; i < initial_points; i++) { + const auto previous = shape.points[i]; + const auto it = shape.points[shape.next[i]]; + const auto next_segment = shape.points[shape.next[shape.next[i]]]; + + const EdgePosition previous_dir = GetSegmentDirection(previous, it); + const EdgePosition next_dir = GetSegmentDirection(it, next_segment); + + // Identify walls (downward vertical segments). + if (previous_dir == EdgePosition::BOTTOM) { + events.push_back( + {.x = previous.first, + .payload = WallEvent{.min_y = std::min(previous.second, it.second), + .max_y = std::max(previous.second, it.second), + .wall_start_idx = i}}); } - }; - struct CmpByStartX { - bool operator()(const PolygonCut& a, const PolygonCut& b) const { - return a.start < b.start; + // Identify ray starts. + if (previous_dir != next_dir) { + if ((previous_dir == EdgePosition::TOP && + next_dir == EdgePosition::LEFT) || + (previous_dir == EdgePosition::RIGHT && + next_dir == EdgePosition::TOP)) { + events.push_back({.x = it.first, + .payload = RayStartEvent{ + .y = it.second, .ray_start_idx = shape.next[i]}}); + } } - }; + } + // Sort events left-to-right. + absl::c_sort(events); + + // Maps Y-coordinate to the starting vertex index of the ray. + absl::btree_map active_rays; + + // Sweep Line pass. + for (const SweepEvent& event : events) { + if (const auto* ray_event = std::get_if(&event.payload)) { + // A new ray starts shooting right. + active_rays[ray_event->y] = ray_event->ray_start_idx; + } else if (const auto* wall_event = + std::get_if(&event.payload)) { + // Find all active rays whose Y coordinate hits this wall. + auto it_start = active_rays.lower_bound(wall_event->min_y); + auto it_end = active_rays.upper_bound(wall_event->max_y); + + for (auto it = it_start; it != it_end; ++it) { + PolygonCut cut; + cut.start = shape.points[it->second]; + cut.end = {event.x, it->first}; + cut.start_index = it->second; + cut.end_index = wall_event->wall_start_idx; + cuts.push_back(cut); + } + active_rays.erase(it_start, it_end); + } + } - struct CmpByEndX { - bool operator()(const PolygonCut& a, const PolygonCut& b) const { - return a.end < b.end; + // Cut the shape along the detected cuts. + const auto apply_cut = [&shape](int seg, + std::pair pt) { + while (true) { + const auto cur = shape.points[seg]; + const auto next = shape.points[shape.next[seg]]; + const IntegerValue min_y = std::min(cur.second, next.second); + const IntegerValue max_y = std::max(cur.second, next.second); + + if (min_y < pt.second && pt.second < max_y) { + shape.points.push_back(pt); + const int next_idx = shape.next[seg]; + shape.next[seg] = shape.points.size() - 1; + shape.next.push_back(next_idx); + return static_cast(shape.points.size() - 1); + } + if (cur == pt) return seg; + if (next == pt) return shape.next[seg]; + seg = shape.next[seg]; } }; - template - friend void AbslStringify(Sink& sink, const PolygonCut& diagonal) { - absl::Format(&sink, "(%v,%v)-(%v,%v)", diagonal.start.first, - diagonal.start.second, diagonal.end.first, - diagonal.end.second); + for (PolygonCut& ray : cuts) { + if (ray.end_index != -1) ray.end_index = apply_cut(ray.end_index, ray.end); } -}; -// A different representation of a shape. The two vectors must have the same -// size. The first one contains the points of the shape and the second one -// contains the index of the next point in the shape. -// -// Note that we code in this file is only correct for shapes with points -// connected only by horizontal or vertical lines. -struct FlatShape { - std::vector> points; - std::vector next; -}; + return cuts; +} -EdgePosition GetSegmentDirection( - const std::pair& curr_segment, - const std::pair& next_segment) { - if (curr_segment.first == next_segment.first) { - return next_segment.second > curr_segment.second ? EdgePosition::TOP - : EdgePosition::BOTTOM; - } else { - return next_segment.first > curr_segment.first ? EdgePosition::RIGHT - : EdgePosition::LEFT; +// Helper functions for rotating a single point. +std::pair RotatePoint90CW( + const std::pair& p) { + return {p.second, -p.first}; +} + +std::pair RotatePoint90CCW( + const std::pair& p) { + return {-p.second, p.first}; +} + +std::pair RotatePoint180( + const std::pair& p) { + return {-p.first, -p.second}; +} + +// Helper to rotate the entire shape in place. +void RotateShape90CW(FlatShape& shape) { + for (auto& p : shape.points) { + p = RotatePoint90CW(p); } } +} // namespace + // Given a polygon, this function returns all line segments that start on a // concave vertex and follow horizontally or vertically until it reaches the // border of the polygon. This function returns all such segments grouped on the @@ -941,295 +973,44 @@ EdgePosition GetSegmentDirection( // even if they are not edges. std::array, 4> GetPotentialPolygonCuts( FlatShape& shape) { - std::array, 4> cuts; + std::array, 4> all_cuts; - // First, for each concave vertex we create a cut that starts at it and - // crosses the polygon until infinite (in practice, int_max/int_min). - for (int i = 0; i < shape.points.size(); i++) { - const auto& it = &shape.points[shape.next[i]]; - const auto& previous = &shape.points[i]; - const auto& next_segment = &shape.points[shape.next[shape.next[i]]]; - const EdgePosition previous_dir = GetSegmentDirection(*previous, *it); - const EdgePosition next_dir = GetSegmentDirection(*it, *next_segment); - - if ((previous_dir == EdgePosition::TOP && next_dir == EdgePosition::LEFT) || - (previous_dir == EdgePosition::RIGHT && - next_dir == EdgePosition::TOP)) { - cuts[EdgePosition::RIGHT].push_back( - {.start = *it, - .end = {std::numeric_limits::max(), it->second}, - .start_index = shape.next[i]}); - } - if ((previous_dir == EdgePosition::BOTTOM && - next_dir == EdgePosition::RIGHT) || - (previous_dir == EdgePosition::LEFT && - next_dir == EdgePosition::BOTTOM)) { - cuts[EdgePosition::LEFT].push_back( - {.start = {std::numeric_limits::min(), it->second}, - .end = *it, - .end_index = shape.next[i]}); - } - if ((previous_dir == EdgePosition::RIGHT && - next_dir == EdgePosition::TOP) || - (previous_dir == EdgePosition::BOTTOM && - next_dir == EdgePosition::RIGHT)) { - cuts[EdgePosition::BOTTOM].push_back( - {.start = {it->first, std::numeric_limits::min()}, - .end = *it, - .end_index = shape.next[i]}); - } - if ((previous_dir == EdgePosition::TOP && next_dir == EdgePosition::LEFT) || - (previous_dir == EdgePosition::LEFT && - next_dir == EdgePosition::BOTTOM)) { - cuts[EdgePosition::TOP].push_back( - {.start = *it, - .end = {it->first, std::numeric_limits::max()}, - .start_index = shape.next[i]}); - } - } + // We will simply call ExtractRightCuts() four times, doing the right + // rotations before and after the call. + all_cuts[EdgePosition::RIGHT] = ExtractRightCuts(shape); - // Now that we have one of the points of the segment (the one starting on a - // vertex), we need to find the other point. This is basically finding the - // first path segment that crosses each cut connecting edge->infinity we - // collected above. We do a rather naive implementation of that below and its - // complexity is O(N^2) even if it should be fast in most cases. If it - // turns out to be costly on profiling we can use a more sophisticated - // algorithm for finding the first intersection. - - // We need to sort the cuts so we can use binary search to quickly find cuts - // that cross a segment. - std::sort(cuts[EdgePosition::RIGHT].begin(), cuts[EdgePosition::RIGHT].end(), - PolygonCut::CmpByStartY()); - std::sort(cuts[EdgePosition::LEFT].begin(), cuts[EdgePosition::LEFT].end(), - PolygonCut::CmpByEndY()); - std::sort(cuts[EdgePosition::BOTTOM].begin(), - cuts[EdgePosition::BOTTOM].end(), PolygonCut::CmpByEndX()); - std::sort(cuts[EdgePosition::TOP].begin(), cuts[EdgePosition::TOP].end(), - PolygonCut::CmpByStartX()); - - // This function cuts a segment in two if it crosses a cut. In any case, it - // returns the index of a point `point_idx` so that `shape.points[point_idx] - // == point_to_cut`. - const auto cut_segment_if_necessary = - [&shape](int segment_idx, - std::pair point_to_cut) { - const auto& cur = shape.points[segment_idx]; - const auto& next = shape.points[shape.next[segment_idx]]; - if (cur.second == next.second) { - DCHECK_EQ(point_to_cut.second, cur.second); - // We have a horizontal segment - const IntegerValue edge_start = std::min(cur.first, next.first); - const IntegerValue edge_end = std::max(cur.first, next.first); - - if (edge_start < point_to_cut.first && - point_to_cut.first < edge_end) { - shape.points.push_back(point_to_cut); - const int next_idx = shape.next[segment_idx]; - shape.next[segment_idx] = shape.points.size() - 1; - shape.next.push_back(next_idx); - return static_cast(shape.points.size() - 1); - } - return (shape.points[segment_idx] == point_to_cut) - ? segment_idx - : shape.next[segment_idx]; - } else { - DCHECK_EQ(cur.first, next.first); - DCHECK_EQ(point_to_cut.first, cur.first); - // We have a vertical segment - const IntegerValue edge_start = std::min(cur.second, next.second); - const IntegerValue edge_end = std::max(cur.second, next.second); - - if (edge_start < point_to_cut.second && - point_to_cut.second < edge_end) { - shape.points.push_back(point_to_cut); - const int next_idx = shape.next[segment_idx]; - shape.next[segment_idx] = shape.points.size() - 1; - shape.next.push_back(next_idx); - return static_cast(shape.points.size() - 1); - } - return (shape.points[segment_idx] == point_to_cut) - ? segment_idx - : shape.next[segment_idx]; - } - }; + RotateShape90CW(shape); + auto top_cuts = ExtractRightCuts(shape); + for (auto& cut : top_cuts) { + cut.start = RotatePoint90CCW(cut.start); + cut.end = RotatePoint90CCW(cut.end); + } + all_cuts[EdgePosition::TOP] = top_cuts; - for (int i = 0; i < shape.points.size(); i++) { - const auto* cur_point_ptr = &shape.points[shape.next[i]]; - const auto* previous = &shape.points[i]; - DCHECK(cur_point_ptr->first == previous->first || - cur_point_ptr->second == previous->second) - << "found a segment that is neither horizontal nor vertical"; - const EdgePosition direction = - GetSegmentDirection(*previous, *cur_point_ptr); - - if (direction == EdgePosition::BOTTOM) { - const auto cut_start = absl::c_lower_bound( - cuts[EdgePosition::RIGHT], - PolygonCut{.start = {std::numeric_limits::min(), - cur_point_ptr->second}}, - PolygonCut::CmpByStartY()); - auto cut_end = absl::c_upper_bound( - cuts[EdgePosition::RIGHT], - PolygonCut{.start = {std::numeric_limits::max(), - previous->second}}, - PolygonCut::CmpByStartY()); - - for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { - PolygonCut& diagonal = *cut_it; - const IntegerValue diagonal_start_x = diagonal.start.first; - const IntegerValue diagonal_cur_end_x = diagonal.end.first; - // Our binary search guarantees those two conditions. - DCHECK_LE(cur_point_ptr->second, diagonal.start.second); - DCHECK_LE(diagonal.start.second, previous->second); - - // Let's test if the diagonal crosses the current boundary segment - if (diagonal_start_x <= previous->first && - diagonal_cur_end_x > cur_point_ptr->first) { - DCHECK_LT(diagonal_start_x, cur_point_ptr->first); - DCHECK_LE(previous->first, diagonal_cur_end_x); - - diagonal.end.first = cur_point_ptr->first; - - diagonal.end_index = cut_segment_if_necessary(i, diagonal.end); - DCHECK(shape.points[diagonal.end_index] == diagonal.end); - - // Subtle: cut_segment_if_necessary might add new points to the vector - // of the shape, so the pointers computed from it might become - // invalid. Moreover, the current segment now is shorter, so we need - // to update our upper bound. - cur_point_ptr = &shape.points[shape.next[i]]; - previous = &shape.points[i]; - cut_end = absl::c_upper_bound( - cuts[EdgePosition::RIGHT], - PolygonCut{.start = {std::numeric_limits::max(), - previous->second}}, - PolygonCut::CmpByStartY()); - } - } - } + RotateShape90CW(shape); + auto left_cuts = ExtractRightCuts(shape); + for (auto& cut : left_cuts) { + cut.start = RotatePoint180(cut.start); + cut.end = RotatePoint180(cut.end); + } + all_cuts[EdgePosition::LEFT] = left_cuts; - if (direction == EdgePosition::TOP) { - const auto cut_start = absl::c_lower_bound( - cuts[EdgePosition::LEFT], - PolygonCut{.end = {std::numeric_limits::min(), - previous->second}}, - PolygonCut::CmpByEndY()); - auto cut_end = absl::c_upper_bound( - cuts[EdgePosition::LEFT], - PolygonCut{.end = {std::numeric_limits::max(), - cur_point_ptr->second}}, - PolygonCut::CmpByEndY()); - for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { - PolygonCut& diagonal = *cut_it; - const IntegerValue diagonal_start_x = diagonal.start.first; - const IntegerValue diagonal_cur_end_x = diagonal.end.first; - // Our binary search guarantees those two conditions. - DCHECK_LE(diagonal.end.second, cur_point_ptr->second); - DCHECK_LE(previous->second, diagonal.end.second); - - // Let's test if the diagonal crosses the current boundary segment - if (diagonal_start_x < cur_point_ptr->first && - previous->first <= diagonal_cur_end_x) { - DCHECK_LT(cur_point_ptr->first, diagonal_cur_end_x); - DCHECK_LE(diagonal_start_x, previous->first); - - diagonal.start.first = cur_point_ptr->first; - diagonal.start_index = cut_segment_if_necessary(i, diagonal.start); - DCHECK(shape.points[diagonal.start_index] == diagonal.start); - cur_point_ptr = &shape.points[shape.next[i]]; - previous = &shape.points[i]; - cut_end = absl::c_upper_bound( - cuts[EdgePosition::LEFT], - PolygonCut{.end = {std::numeric_limits::max(), - cur_point_ptr->second}}, - PolygonCut::CmpByEndY()); - } - } - } + RotateShape90CW(shape); + auto bottom_cuts = ExtractRightCuts(shape); + for (auto& cut : bottom_cuts) { + cut.start = RotatePoint90CW(cut.start); + cut.end = RotatePoint90CW(cut.end); + } + all_cuts[EdgePosition::BOTTOM] = bottom_cuts; - if (direction == EdgePosition::LEFT) { - const auto cut_start = absl::c_lower_bound( - cuts[EdgePosition::BOTTOM], - PolygonCut{.end = {cur_point_ptr->first, - std::numeric_limits::min()}}, - PolygonCut::CmpByEndX()); - auto cut_end = absl::c_upper_bound( - cuts[EdgePosition::BOTTOM], - PolygonCut{.end = {previous->first, - std::numeric_limits::max()}}, - PolygonCut::CmpByEndX()); - for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { - PolygonCut& diagonal = *cut_it; - const IntegerValue diagonal_start_y = diagonal.start.second; - const IntegerValue diagonal_cur_end_y = diagonal.end.second; - - // Our binary search guarantees those two conditions. - DCHECK_LE(cur_point_ptr->first, diagonal.end.first); - DCHECK_LE(diagonal.end.first, previous->first); - - // Let's test if the diagonal crosses the current boundary segment - if (diagonal_start_y < cur_point_ptr->second && - cur_point_ptr->second <= diagonal_cur_end_y) { - DCHECK_LE(diagonal_start_y, previous->second); - DCHECK_LT(cur_point_ptr->second, diagonal_cur_end_y); - - diagonal.start.second = cur_point_ptr->second; - diagonal.start_index = cut_segment_if_necessary(i, diagonal.start); - DCHECK(shape.points[diagonal.start_index] == diagonal.start); - cur_point_ptr = &shape.points[shape.next[i]]; - previous = &shape.points[i]; - cut_end = absl::c_upper_bound( - cuts[EdgePosition::BOTTOM], - PolygonCut{.end = {previous->first, - std::numeric_limits::max()}}, - PolygonCut::CmpByEndX()); - } - } - } + // Restore shape to original orientation. + RotateShape90CW(shape); - if (direction == EdgePosition::RIGHT) { - const auto cut_start = absl::c_lower_bound( - cuts[EdgePosition::TOP], - PolygonCut{.start = {previous->first, - std::numeric_limits::min()}}, - PolygonCut::CmpByStartX()); - auto cut_end = absl::c_upper_bound( - cuts[EdgePosition::TOP], - PolygonCut{.start = {cur_point_ptr->first, - std::numeric_limits::max()}}, - PolygonCut::CmpByStartX()); - for (auto cut_it = cut_start; cut_it < cut_end; ++cut_it) { - PolygonCut& diagonal = *cut_it; - const IntegerValue diagonal_start_y = diagonal.start.second; - const IntegerValue diagonal_cur_end_y = diagonal.end.second; - - // Our binary search guarantees those two conditions. - DCHECK_LE(previous->first, diagonal.start.first); - DCHECK_LE(diagonal.start.first, cur_point_ptr->first); - - // Let's test if the diagonal crosses the current boundary segment - if (diagonal_start_y <= cur_point_ptr->second && - cur_point_ptr->second < diagonal_cur_end_y) { - DCHECK_LT(diagonal_start_y, previous->second); - DCHECK_LE(cur_point_ptr->second, diagonal_cur_end_y); - - diagonal.end.second = cur_point_ptr->second; - diagonal.end_index = cut_segment_if_necessary(i, diagonal.end); - DCHECK(shape.points[diagonal.end_index] == diagonal.end); - cur_point_ptr = &shape.points[shape.next[i]]; - cut_end = absl::c_upper_bound( - cuts[EdgePosition::TOP], - PolygonCut{.start = {cur_point_ptr->first, - std::numeric_limits::max()}}, - PolygonCut::CmpByStartX()); - previous = &shape.points[i]; - } - } - } - } - return cuts; + return all_cuts; } +namespace { + void CutShapeWithPolygonCuts(FlatShape& shape, absl::Span cuts) { std::vector previous(shape.points.size(), -1); @@ -1265,12 +1046,7 @@ void CutShapeWithPolygonCuts(FlatShape& shape, } } // namespace -// This function applies the method described in page 3 of [1]. -// -// [1] Eppstein, David. "Graph-theoretic solutions to computational geometry -// problems." International Workshop on Graph-Theoretic Concepts in Computer -// Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2009. -std::vector CutShapeIntoRectangles(SingleShape shape) { +FlatShape BuildFlatShape(const SingleShape& shape) { auto is_aligned = [](const std::pair& p1, const std::pair& p2, const std::pair& p3) { @@ -1295,7 +1071,7 @@ std::vector CutShapeIntoRectangles(SingleShape shape) { // To cut our polygon into rectangles, we first put it into a data structure // that is easier to manipulate. FlatShape flat_shape; - for (int i = 0; 1 + i < shape.boundary.step_points.size(); ++i) { + for (int i = 0; i < shape.boundary.step_points.size(); ++i) { const std::pair& segment = shape.boundary.step_points[i]; add_segment(segment, 0, flat_shape.points, flat_shape.next); @@ -1303,8 +1079,8 @@ std::vector CutShapeIntoRectangles(SingleShape shape) { flat_shape.next.back() = 0; for (const ShapePath& hole : shape.holes) { const int start = flat_shape.next.size(); - if (hole.step_points.size() < 2) continue; - for (int i = 0; i + 1 < hole.step_points.size(); ++i) { + if (hole.step_points.empty()) continue; + for (int i = 0; i < hole.step_points.size(); ++i) { const std::pair& segment = hole.step_points[i]; add_segment(segment, start, flat_shape.points, flat_shape.next); @@ -1312,6 +1088,16 @@ std::vector CutShapeIntoRectangles(SingleShape shape) { flat_shape.next.back() = start; } + return flat_shape; +} + +// This function applies the method described in page 3 of [1]. +// +// [1] Eppstein, David. "Graph-theoretic solutions to computational geometry +// problems." International Workshop on Graph-Theoretic Concepts in Computer +// Science. Berlin, Heidelberg: Springer Berlin Heidelberg, 2009. +std::vector CutShapeIntoRectangles(const SingleShape& shape) { + FlatShape flat_shape = BuildFlatShape(shape); std::array, 4> all_cuts = GetPotentialPolygonCuts(flat_shape); @@ -1323,16 +1109,31 @@ std::vector CutShapeIntoRectangles(SingleShape shape) { // ours that we are only talking of diagonals that are not "diagonal" in the // usual meaning of the word: ie., horizontal or vertical segments connecting // two vertices of the polygon). + + absl::flat_hash_set< + std::tuple> + seen_diagonals; std::array, 2> good_diagonals; for (const auto& d : all_cuts[EdgePosition::BOTTOM]) { - if (absl::c_binary_search(all_cuts[EdgePosition::TOP], d, - PolygonCut::CmpByStartX())) { + seen_diagonals.insert( + {d.start.first, d.start.second, d.end.first, d.end.second}); + } + for (const auto& d : all_cuts[EdgePosition::TOP]) { + // Good horizontal diagonals appear twice: once in BOTTOM and once, + // reversed, in TOP. + if (seen_diagonals.contains( + {d.end.first, d.end.second, d.start.first, d.start.second})) { good_diagonals[0].push_back(d); } } + seen_diagonals.clear(); for (const auto& d : all_cuts[EdgePosition::LEFT]) { - if (absl::c_binary_search(all_cuts[EdgePosition::RIGHT], d, - PolygonCut::CmpByStartY())) { + seen_diagonals.insert( + {d.start.first, d.start.second, d.end.first, d.end.second}); + } + for (const auto& d : all_cuts[EdgePosition::RIGHT]) { + if (seen_diagonals.contains( + {d.end.first, d.end.second, d.start.first, d.start.second})) { good_diagonals[1].push_back(d); } } @@ -1387,10 +1188,14 @@ std::vector CutShapeIntoRectangles(SingleShape shape) { // Remove duplicates coming from good diagonals first. std::vector cuts = all_cuts[EdgePosition::TOP]; - for (const auto& cut : all_cuts[EdgePosition::BOTTOM]) { - if (!absl::c_binary_search(all_cuts[EdgePosition::TOP], cut, - PolygonCut::CmpByStartX())) { - cuts.push_back(cut); + for (const auto& d : all_cuts[EdgePosition::TOP]) { + seen_diagonals.insert( + {d.start.first, d.start.second, d.end.first, d.end.second}); + } + for (const auto& d : all_cuts[EdgePosition::BOTTOM]) { + if (!seen_diagonals.contains( + {d.end.first, d.end.second, d.start.first, d.start.second})) { + cuts.push_back(d); } } @@ -1437,8 +1242,13 @@ bool ReduceNumberOfBoxesExactMandatory( for (const Rectangle& box : *mandatory_rectangles) { mandatory_bounding_box.GrowToInclude(box); } - const std::vector mandatory_empty_holes = - FindEmptySpaces(mandatory_bounding_box, *mandatory_rectangles); + const std::vector empty_spaces = FindEmptySpacesHorizontally( + mandatory_bounding_box, *mandatory_rectangles); + std::vector mandatory_empty_holes; + mandatory_empty_holes.reserve(empty_spaces.size()); + for (const auto& space : empty_spaces) { + mandatory_empty_holes.push_back(space.rect); + } const std::vector> mandatory_holes_components = SplitInConnectedComponents(BuildNeighboursGraph(mandatory_empty_holes)); diff --git a/ortools/sat/2d_rectangle_presolve.h b/ortools/sat/2d_rectangle_presolve.h index 19a4cbe6661..54c112fd37f 100644 --- a/ortools/sat/2d_rectangle_presolve.h +++ b/ortools/sat/2d_rectangle_presolve.h @@ -14,17 +14,17 @@ #ifndef ORTOOLS_SAT_2D_RECTANGLE_PRESOLVE_H_ #define ORTOOLS_SAT_2D_RECTANGLE_PRESOLVE_H_ +#include #include #include #include #include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" #include "absl/log/log.h" #include "absl/types/span.h" #include "ortools/sat/diffn_util.h" #include "ortools/sat/integer_base.h" +#include "ortools/sat/util.h" namespace operations_research { namespace sat { @@ -190,11 +190,21 @@ class Neighbours { absl::Span rectangles, absl::Span> neighbors) : size_(rectangles.size()) { + int total_neighbors_per_position[4] = {0}; for (const auto& [box_index, edge, neighbor] : neighbors) { - neighbors_[edge][box_index].push_back(neighbor); + total_neighbors_per_position[edge]++; } + CompactVectorVectorBuilder builders[4]; for (int edge = 0; edge < 4; ++edge) { - for (auto& [box_index, neighbors] : neighbors_[edge]) { + builders[edge].ReserveNumItems(total_neighbors_per_position[edge]); + } + for (const auto& [box_index, edge, neighbor] : neighbors) { + builders[edge].Add(box_index, neighbor); + } + for (int edge = 0; edge < 4; ++edge) { + neighbors_[edge].ResetFromBuilder(builders[edge], size_); + for (int box_index = 0; box_index < size_; ++box_index) { + absl::Span neighbors = neighbors_[edge][box_index]; absl::c_sort(neighbors, [&rectangles, edge](int a, int b) { return CompareClockwise(static_cast(edge))( rectangles[a], rectangles[b]); @@ -208,16 +218,11 @@ class Neighbours { // Neighbors are sorted in the clockwise order. absl::Span GetSortedNeighbors(int rectangle_index, EdgePosition edge) const { - if (auto it = neighbors_[edge].find(rectangle_index); - it != neighbors_[edge].end()) { - return it->second; - } else { - return {}; - } + return neighbors_[edge][rectangle_index]; } private: - absl::flat_hash_map> neighbors_[4]; + CompactVectorVector neighbors_[4]; int size_; }; @@ -233,24 +238,65 @@ std::vector> SplitInConnectedComponents( // union of rectangles, the path is a subset of the union of all the rectangle's // edges. struct ShapePath { - // The two vectors should have exactly the same size. + // The two vectors have exactly the same size, representing a closed cyclic + // loop. The last point conceptually connects back to the first point. std::vector> step_points; // touching_box_index[i] contains the index of the unique interior rectangle // touching the segment step_points[i]->step_points[(i+1)%size]. std::vector touching_box_index; }; +// Represents a single contiguous orthogonal shape (a connected component of +// rectangles) defined by its outer envelope and any internal voids. +// +// Topological orientation ("solid on right"): +// - `boundary`: The exterior envelope, traced in clockwise order. +// - `holes`: The internal voids, naturally traced in counter-clockwise +// order to maintain the interior solid mass on the right side of the path. struct SingleShape { ShapePath boundary; std::vector holes; }; -// Given a set of rectangles, split it into connected components and transform -// each individual set into a shape described by its boundary and holes paths. +// Given a set of non-overlapping rectangles, splits them into connected +// components with two rectangles being connected if they share a side of +// non-zero length and transforms each set into a shape described by its +// boundary and hole paths. Then, for each component, convert the set of +// rectangles into a single shape with a boundary and zero or more holes. +// +// Note that if two rectangles touch only at a single diagonal point, there are +// two possible ways to continue the contour. We make the choice of continuing +// on the same rectangle it was touching before the pinch-point. std::vector BoxesToShapes(absl::Span rectangles, const Neighbours& neighbours); -std::vector CutShapeIntoRectangles(SingleShape shapes); +// Given a single shape, returns a set of non-overlapping rectangles that +// is a paving of the shape. This function always returns the minimum possible +// number of rectangles that can exactly pave the shape. +std::vector CutShapeIntoRectangles(const SingleShape& shapes); + +struct PolygonCut { + std::pair start; + std::pair end; + int start_index; + int end_index; +}; + +// A different representation of a shape. The two vectors must have the same +// size. The first one contains the points of the shape and the second one +// contains the index of the next point in the shape. +// +// Note that we code in this file is only correct for shapes with points +// connected only by horizontal or vertical lines. +struct FlatShape { + std::vector> points; + std::vector next; +}; + +// Exposed for testing, documented on the .cc file. +FlatShape BuildFlatShape(const SingleShape& shape); +std::array, 4> GetPotentialPolygonCuts( + FlatShape& shape); } // namespace sat } // namespace operations_research diff --git a/ortools/sat/2d_rectangle_presolve_test.cc b/ortools/sat/2d_rectangle_presolve_test.cc index f02a3338ce6..353458ab813 100644 --- a/ortools/sat/2d_rectangle_presolve_test.cc +++ b/ortools/sat/2d_rectangle_presolve_test.cc @@ -14,6 +14,8 @@ #include "ortools/sat/2d_rectangle_presolve.h" #include +#include +#include #include #include #include @@ -30,7 +32,9 @@ #include "absl/log/log.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "gtest/gtest.h" #include "ortools/base/gmock.h" @@ -268,9 +272,7 @@ std::string RenderContour(std::optional bb, std::pair p = path.step_points[i]; ss << " p" << i << "[pos=\"" << 2 * p.first << "," << 2 * p.second << "!\" shape=point]\n"; - if (i != path.step_points.size() - 1) { - ss << " p" << i << "->p" << i + 1 << "\n"; - } + ss << " p" << i << "->p" << ((i + 1) % path.step_points.size()) << "\n"; } return RenderDot(bb, rectangles, ss.str()); } @@ -652,9 +654,9 @@ ShapePath TraceBoundary( next_direction = EdgePosition::BOTTOM; } } else if (touching_edge[EdgePosition::TOP]) { - next_direction = EdgePosition::LEFT; - } else if (touching_edge[EdgePosition::BOTTOM]) { next_direction = EdgePosition::RIGHT; + } else if (touching_edge[EdgePosition::BOTTOM]) { + next_direction = EdgePosition::LEFT; } else { LOG(FATAL) << "TraceBoundary() got a `starting_step_point` that is not in an edge " @@ -671,6 +673,8 @@ ShapePath TraceBoundary( result.step_points.back() == result.step_points.front() && point.x == result.step_points[1].first && point.y == result.step_points[1].second) { + result.step_points.pop_back(); + result.touching_box_index.pop_back(); break; } if (!result.step_points.empty() && @@ -702,7 +706,8 @@ ShapePath TraceBoundary( std::string RenderShapes(std::optional bb, absl::Span rectangles, - absl::Span shapes) { + absl::Span shapes, + std::string_view extra_dot_payload = "") { const std::vector colors = {"black", "white", "orange", "cyan", "yellow", "purple"}; std::stringstream ss; @@ -734,9 +739,41 @@ std::string RenderShapes(std::optional bb, } } } + + ss << extra_dot_payload; + return RenderDot(bb, rectangles, ss.str()); } +std::string RenderCuts(const std::array, 4>& cuts) { + std::stringstream ss; + + // Distinct colors for TOP, RIGHT, BOTTOM, LEFT + const std::string_view cut_colors[] = {"red", "blue", "green", "magenta"}; + + for (int dir = 0; dir < 4; ++dir) { + for (int i = 0; i < cuts[dir].size(); ++i) { + const auto& cut = cuts[dir][i]; + + // Give nodes unique names to avoid colliding with RenderShapes' 'p' nodes + ss << " cut_" << dir << "_" << i << "_start[pos=\"" + << 2 * cut.start.first << "," << 2 * cut.start.second + << "!\" shape=point width=0 height=0]\n"; + + ss << " cut_" << dir << "_" << i << "_end[pos=\"" << 2 * cut.end.first + << "," << 2 * cut.end.second << "!\" shape=point width=0 height=0]\n"; + + // Use penwidth=15 so it's thinner than boundaries, and dashed so we can + // see through it + ss << " cut_" << dir << "_" << i << "_start->cut_" << dir << "_" << i + << "_end" + << " [color=\"" << cut_colors[dir] + << "\", penwidth=15, style=dashed];\n"; + } + } + return ss.str(); +} + TEST(ContourTest, Random) { constexpr int kNumRuns = 100; absl::BitGen bit_gen; @@ -758,36 +795,37 @@ TEST(ContourTest, Random) { const Neighbours neighbours = BuildNeighboursGraph(fixed_rectangles); const auto components = SplitInConnectedComponents(neighbours); const Rectangle bb = {.x_min = 0, .x_max = 100, .y_min = 0, .y_max = 100}; - int min_index = -1; - std::pair min_coord = { - std::numeric_limits::max(), - std::numeric_limits::max()}; - for (const int box_index : components[0]) { - const Rectangle& rectangle = fixed_rectangles[box_index]; - if (std::make_pair(rectangle.x_min, rectangle.y_min) < min_coord) { - min_coord = {rectangle.x_min, rectangle.y_min}; - min_index = box_index; - } - } const std::vector shapes = BoxesToShapes(fixed_rectangles, neighbours); - for (const SingleShape& shape : shapes) { - const ShapePath& boundary = shape.boundary; - const ShapePath expected_shape = - TraceBoundary(boundary.step_points[0], boundary.touching_box_index[0], + + // A helper lambda to test any ShapePath (boundary or hole) + auto verify_path = [&](const ShapePath& path, + const std::string& path_name) { + if (path.step_points.empty()) return; + + const ShapePath expected_path = + TraceBoundary(path.step_points[0], path.touching_box_index[0], fixed_rectangles, neighbours); - if (boundary.step_points != expected_shape.step_points) { - LOG(ERROR) << "Fast algo:\n" - << RenderContour(bb, fixed_rectangles, boundary); - LOG(ERROR) << "Naive algo:\n" - << RenderContour(bb, fixed_rectangles, expected_shape); + + if (path.step_points != expected_path.step_points) { + LOG(ERROR) << "Fast algo (" << path_name << "):\n" + << RenderContour(bb, fixed_rectangles, path); + LOG(ERROR) << "Naive algo (" << path_name << "):\n" + << RenderContour(bb, fixed_rectangles, expected_path); LOG(FATAL) << "Found different solutions between naive and fast algo!"; } - EXPECT_EQ(boundary.step_points, expected_shape.step_points); - EXPECT_EQ(boundary.touching_box_index, expected_shape.touching_box_index); - } + EXPECT_EQ(path.step_points, expected_path.step_points); + EXPECT_EQ(path.touching_box_index, expected_path.touching_box_index); + }; + // Test both the exterior boundaries AND the internal holes + for (const SingleShape& shape : shapes) { + verify_path(shape.boundary, "Exterior Boundary"); + for (int i = 0; i < shape.holes.size(); ++i) { + verify_path(shape.holes[i], absl::StrCat("Hole ", i)); + } + } if (run == 0) { LOG(INFO) << RenderShapes(bb, fixed_rectangles, shapes); } @@ -800,48 +838,301 @@ TEST(ContourTest, SimpleShapes) { {.x_min = 3, .x_max = 8, .y_min = 0, .y_max = 10}}; ShapePath shape = TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); - EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 0, 1, 1, 1, 0, 0, 0)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 0, 1, 1, 1, 0, 0)); EXPECT_THAT(shape.step_points, ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), std::make_pair(10, 10), std::make_pair(8, 10), std::make_pair(8, 0), std::make_pair(3, 0), - std::make_pair(3, 10), std::make_pair(0, 10), - std::make_pair(0, 20))); + std::make_pair(3, 10), std::make_pair(0, 10))); rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}}; shape = TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); - EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 0, 0)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 0)); EXPECT_THAT(shape.step_points, ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), std::make_pair(10, 10), std::make_pair(10, 0), - std::make_pair(0, 0), std::make_pair(0, 10), - std::make_pair(0, 20))); + std::make_pair(0, 0), std::make_pair(0, 10))); rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, {.x_min = 0, .x_max = 15, .y_min = 0, .y_max = 10}}; shape = TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); - EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 1, 0, 0)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 1, 1, 1, 1, 0)); EXPECT_THAT(shape.step_points, ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), std::make_pair(10, 10), std::make_pair(15, 10), std::make_pair(15, 0), std::make_pair(0, 0), - std::make_pair(0, 10), std::make_pair(0, 20))); + std::make_pair(0, 10))); rectangles = {{.x_min = 0, .x_max = 10, .y_min = 10, .y_max = 20}, {.x_min = 0, .x_max = 10, .y_min = 0, .y_max = 10}, {.x_min = 10, .x_max = 20, .y_min = 0, .y_max = 10}}; shape = TraceBoundary({0, 20}, 0, rectangles, BuildNeighboursGraph(rectangles)); - EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 2, 2, 2, 1, 1, 0, 0)); + EXPECT_THAT(shape.touching_box_index, ElementsAre(0, 0, 2, 2, 2, 1, 1, 0)); EXPECT_THAT(shape.step_points, ElementsAre(std::make_pair(0, 20), std::make_pair(10, 20), std::make_pair(10, 10), std::make_pair(20, 10), std::make_pair(20, 0), std::make_pair(10, 0), - std::make_pair(0, 0), std::make_pair(0, 10), - std::make_pair(0, 20))); + std::make_pair(0, 0), std::make_pair(0, 10))); +} + +// Naive but long and inefficient implementation of GetPotentialPolygonCuts. +std::array, 4> GetPotentialPolygonCutsNaive( + FlatShape& shape) { + std::array, 4> cuts; + + int original_N = shape.points.size(); + if (original_N == 0) return cuts; + + // 1. Build a `prev` lookup array to easily access the previous vertex + // for any given vertex without an expensive linear search. + std::vector prev(original_N); + for (int i = 0; i < original_N; ++i) { + prev[shape.next[i]] = i; + } + + // We only iterate up to `original_N`. Any points added during edge-splitting + // are perfectly straight/collinear and cannot be concave vertices. + for (int i = 0; i < original_N; ++i) { + auto A = shape.points[prev[i]]; + auto B = shape.points[i]; + auto C = shape.points[shape.next[i]]; + + // Vector u = A -> B, Vector v = B -> C + IntegerValue ux = B.first - A.first; + IntegerValue uy = B.second - A.second; + IntegerValue vx = C.first - B.first; + IntegerValue vy = C.second - B.second; + + // 2. Identify Concave Vertices + // Cross product: ux*vy - uy*vx. + // Given "solid on right", a strictly positive cross product means a left + // turn, which corresponds to a 270-degree concave interior angle. + IntegerValue cross_product = ux * vy - uy * vx; + + if (cross_product > 0) { + std::vector inward_dirs; + + // 3. Determine the two inward shooting directions. + // For a concave vertex, these are the forward continuations of A->B and + // C->B. Continuation of A->B (same direction as u) + if (ux > 0) + inward_dirs.push_back(EdgePosition::RIGHT); + else if (ux < 0) + inward_dirs.push_back(EdgePosition::LEFT); + else if (uy > 0) + inward_dirs.push_back(EdgePosition::TOP); // Assumes +Y is Up + else if (uy < 0) + inward_dirs.push_back(EdgePosition::BOTTOM); // Assumes -Y is Down + + // Continuation of C->B (opposite direction of v) + if (vx > 0) + inward_dirs.push_back(EdgePosition::LEFT); + else if (vx < 0) + inward_dirs.push_back(EdgePosition::RIGHT); + else if (vy > 0) + inward_dirs.push_back(EdgePosition::BOTTOM); + else if (vy < 0) + inward_dirs.push_back(EdgePosition::TOP); + + // 4. Cast rays in both inward directions to find the closest opposite + // edge + for (EdgePosition dir : inward_dirs) { + IntegerValue min_dist = std::numeric_limits::max(); + int best_edge = -1; + std::pair hit_point; + + // Re-evaluate size on every ray cast to include dynamically split + // edges! + int current_N = shape.points.size(); + + for (int k = 0; k < current_N; ++k) { + auto P1 = shape.points[k]; + auto P2 = shape.points[shape.next[k]]; + + // Check for intersection with perpendicular edges ahead of the ray + if (dir == EdgePosition::RIGHT) { + if (P1.first == P2.first && + P1.first > B.first) { // Vertical edge to the right + if (std::min(P1.second, P2.second) <= B.second && + B.second <= std::max(P1.second, P2.second)) { + IntegerValue dist = P1.first - B.first; + if (dist < min_dist) { + min_dist = dist; + best_edge = k; + hit_point = {P1.first, B.second}; + } + } + } + } else if (dir == EdgePosition::LEFT) { + if (P1.first == P2.first && + P1.first < B.first) { // Vertical edge to the left + if (std::min(P1.second, P2.second) <= B.second && + B.second <= std::max(P1.second, P2.second)) { + IntegerValue dist = B.first - P1.first; + if (dist < min_dist) { + min_dist = dist; + best_edge = k; + hit_point = {P1.first, B.second}; + } + } + } + } else if (dir == EdgePosition::TOP) { + if (P1.second == P2.second && + P1.second > B.second) { // Horizontal edge above + if (std::min(P1.first, P2.first) <= B.first && + B.first <= std::max(P1.first, P2.first)) { + IntegerValue dist = P1.second - B.second; + if (dist < min_dist) { + min_dist = dist; + best_edge = k; + hit_point = {B.first, P1.second}; + } + } + } + } else if (dir == EdgePosition::BOTTOM) { + if (P1.second == P2.second && + P1.second < B.second) { // Horizontal edge below + if (std::min(P1.first, P2.first) <= B.first && + B.first <= std::max(P1.first, P2.first)) { + IntegerValue dist = B.second - P1.second; + if (dist < min_dist) { + min_dist = dist; + best_edge = k; + hit_point = {B.first, P1.second}; + } + } + } + } + } + + // 5. Apply the cut and split the edge if necessary + if (best_edge != -1) { + int end_idx = -1; + auto P1 = shape.points[best_edge]; + auto P2 = shape.points[shape.next[best_edge]]; + + if (hit_point == P1) { + end_idx = best_edge; + } else if (hit_point == P2) { + end_idx = shape.next[best_edge]; + } else { + // The cut landed in the middle of a segment. Split it! + end_idx = shape.points.size(); + shape.points.push_back(hit_point); + + // Maintain the circular linked list geometry + shape.next.push_back(shape.next[best_edge]); + shape.next[best_edge] = end_idx; + } + + PolygonCut cut; + cut.start = B; + cut.end = hit_point; + cut.start_index = i; + cut.end_index = end_idx; + + cuts[dir].push_back(cut); + } + } + } + } + + return cuts; +} + +TEST(GetPotentialPolygonCutsTest, Random) { + constexpr int kNumRuns = 1000; + absl::BitGen bit_gen; + + for (int run = 0; run < kNumRuns; ++run) { + // Start by generating a feasible problem that we know the solution with + // some items fixed. + std::vector input = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, bit_gen); + std::shuffle(input.begin(), input.end(), bit_gen); + const int num_fixed_rectangles = input.size() * 2 / 3; + const absl::Span fixed_rectangles = + absl::MakeConstSpan(input).subspan(0, num_fixed_rectangles); + const absl::Span other_rectangles = + absl::MakeSpan(input).subspan(num_fixed_rectangles); + const std::vector input_in_range = + MakeItemsFromRectangles(other_rectangles, 0.6, bit_gen); + + const Neighbours neighbours = BuildNeighboursGraph(fixed_rectangles); + const auto components = SplitInConnectedComponents(neighbours); + std::vector shapes = + BoxesToShapes(fixed_rectangles, neighbours); + for (auto& shape : shapes) { + FlatShape flat_shape = BuildFlatShape(shape); + FlatShape flat_shape_naive = BuildFlatShape(shape); + auto cuts = GetPotentialPolygonCuts(flat_shape); + auto cuts_naive = GetPotentialPolygonCutsNaive(flat_shape_naive); + for (auto* cut_sol_ptr : {&cuts, &cuts_naive}) { + auto& cut_sol = *cut_sol_ptr; + for (int i = 0; i < 4; ++i) { + absl::c_sort( + cut_sol[i], [](const PolygonCut& a, const PolygonCut& b) { + return std::tie(a.start, a.end) < std::tie(b.start, b.end); + }); + } + } + for (int i = 0; i < 4; ++i) { + auto check = [&](bool condition, absl::string_view message) { + if (!condition) { + std::cerr << message << "\nOptimized:\n" + << RenderShapes(std::nullopt, fixed_rectangles, shapes, + RenderCuts(cuts)) + << "Naive:\n" + << RenderShapes(std::nullopt, fixed_rectangles, shapes, + RenderCuts(cuts_naive)); + LOG(FATAL) << message; + } + }; + // 1. Both algorithms MUST find the exact same number of concave cuts + check(cuts[i].size() == cuts_naive[i].size(), + absl::StrCat("Mismatch in cut count for direction ", i)); + + for (int k = 0; k < cuts[i].size(); ++k) { + const auto& opt_cut = cuts[i][k]; + const auto& naive_cut = cuts_naive[i][k]; + + // 2. Prevent segfaults by guaranteeing indices were correctly + // assigned + ASSERT_GE(opt_cut.start_index, 0); + ASSERT_GE(opt_cut.end_index, 0); + ASSERT_GE(naive_cut.start_index, 0); + ASSERT_GE(naive_cut.end_index, 0); + + // 3. The physical coordinates of the rays must be strictly identical + check(opt_cut.start == naive_cut.start, + absl::StrCat("Ray start coordinate mismatch in direction ", i)); + check(opt_cut.end == naive_cut.end, + absl::StrCat("Ray end coordinate mismatch in direction ", i)); + + // 4. The Graph Topology check (The most important part!) + // This proves that regardless of mutation order or array length, + // the linked list graph points exactly to the expected geometric + // coordinate. + check(flat_shape.points[opt_cut.start_index] == + flat_shape_naive.points[naive_cut.start_index], + "Ray start index coordinate mismatch in direction"); + + check(flat_shape.points[opt_cut.end_index] == + flat_shape_naive.points[naive_cut.end_index], + "Ray end index coordinate mismatch in direction"); + + // 5. Internal struct consistency + // Ensure the pointer actually leads to the recorded coordinate + CHECK_EQ(flat_shape.points[opt_cut.start_index], opt_cut.start); + CHECK_EQ(flat_shape.points[opt_cut.end_index], opt_cut.end); + } + } + } + } } TEST(ContourTest, ExampleFromPaper) { diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 8eef76bff82..e5c196a1866 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -35,6 +35,7 @@ cc_library( ":cp_model_utils", ":model", ":sat_parameters_cc_proto", + "//ortools/base:types", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", @@ -59,6 +60,7 @@ cc_test( "//ortools/base:gmock_main", "//ortools/base:log_severity", "//ortools/base:parse_test_proto", + "//ortools/base:types", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", @@ -134,6 +136,7 @@ cc_library( "//ortools/base:log_severity", "//ortools/base:stl_util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/graph_base:connected_components", "//ortools/util:bitset", "//ortools/util:logging", @@ -152,6 +155,49 @@ cc_library( ], ) +cc_library( + name = "scheduling_local_search", + srcs = ["scheduling_local_search.cc"], + hdrs = ["scheduling_local_search.h"], + deps = [ + ":combine_solutions", + ":cp_model_checker", + ":cp_model_utils", + ":integer_base", + ":scheduling_model", + ":stat_tables", + ":subsolver", + ":synchronization", + ":util", + "//ortools/base:types", + "//ortools/graph_base:topologicalsorter", + "//ortools/util:random_engine", + "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/log", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:vlog_is_on", + "@abseil-cpp//absl/random:bit_gen_ref", + "@abseil-cpp//absl/random:distributions", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings:string_view", + "@abseil-cpp//absl/types:span", + ], +) + +cc_test( + name = "scheduling_local_search_test", + srcs = ["scheduling_local_search_test.cc"], + deps = [ + ":integer_base", + ":scheduling_local_search", + ":scheduling_model", + ":util", + "//ortools/base:gmock_main", + "@abseil-cpp//absl/types:span", + ], +) + proto_library( name = "cp_model_proto", srcs = ["cp_model.proto"], @@ -174,7 +220,9 @@ cc_library( ":scheduling_helpers", ":synchronization", "//ortools/base:stl_util", + "//ortools/util:time_limit", "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/container:fixed_array", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:inlined_vector", @@ -217,12 +265,17 @@ cc_library( ":sat_parameters_cc_proto", ":synchronization", "//ortools/base:protobuf_util", + "//ortools/base:types", "//ortools/base/macros:os_support", + "//ortools/util:logging", + "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", "@protobuf", @@ -253,6 +306,7 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/algorithm:container", @@ -303,6 +357,7 @@ cc_library( deps = [ ":sat_base", "//ortools/base:stl_util", + "//ortools/base:types", "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", @@ -346,11 +401,13 @@ cc_library( ":cp_model_utils", ":sat_base", ":util", + "//ortools/base:file", "//ortools/base:stl_util", "//ortools/base:strong_vector", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/log:vlog_is_on", "@abseil-cpp//absl/numeric:bits", "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/random:distributions", @@ -453,6 +510,7 @@ cc_library( ":integer_base", ":presolve_context", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/algorithm:container", @@ -512,14 +570,22 @@ cc_library( hdrs = ["scheduling_model.h"], deps = [ ":cp_model_cc_proto", + ":cp_model_utils", + ":diffn_util", ":integer_base", ":util", + "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/graph_base:connected_components", + "//ortools/graph_base:topologicalsorter", + "//ortools/util:bitset", + "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/status:statusor", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/strings:str_format", "@abseil-cpp//absl/types:span", @@ -557,6 +623,7 @@ cc_library( "//ortools/base:file", "//ortools/base:hash", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/base/macros:os_support", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", @@ -609,6 +676,7 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/base:timer", + "//ortools/base:types", "//ortools/base/macros:os_support", "//ortools/util:bitset", "//ortools/util:logging", @@ -648,6 +716,7 @@ cc_test( ":util", "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", + "//ortools/base:types", "//ortools/util:random_engine", "@abseil-cpp//absl/time", "@abseil-cpp//absl/types:span", @@ -665,6 +734,7 @@ cc_library( ":diffn_util", ":primary_variables", ":sat_parameters_cc_proto", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", @@ -688,6 +758,7 @@ cc_test( ":cp_model_checker", "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", + "//ortools/base:types", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", ], @@ -701,11 +772,15 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", + ":diffn_util", + ":integer_base", ":sat_parameters_cc_proto", ":util", "//ortools/base:mathutil", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/graph_base:strongly_connected_components", + "//ortools/port:proto_utils", "//ortools/util:bitset", "//ortools/util:dense_set", "//ortools/util:saturated_arithmetic", @@ -716,6 +791,7 @@ cc_library( "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", ], ) @@ -749,6 +825,7 @@ cc_library( ":cp_model_cc_proto", ":cp_model_checker", ":cp_model_copy", + ":cp_model_utils", ":integer_base", ":linear_model", ":restart", @@ -758,11 +835,15 @@ cc_library( ":synchronization", ":util", "//ortools/algorithms:binary_search", + "//ortools/base:log_severity", + "//ortools/base:types", "//ortools/util:logging", + "//ortools/util:random_engine", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", "//ortools/util:time_limit", "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/flags:flag", "@abseil-cpp//absl/functional:any_invocable", "@abseil-cpp//absl/functional:bind_front", "@abseil-cpp//absl/functional:function_ref", @@ -823,6 +904,7 @@ cc_library( deps = [ ":cp_model_search", ":sat_parameters_cc_proto", + "//ortools/base:types", "@abseil-cpp//absl/strings", ], ) @@ -856,6 +938,7 @@ cc_library( ":sat_parameters_cc_proto", ":util", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:strong_integers", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/container:flat_hash_map", @@ -880,6 +963,7 @@ cc_test( ":sat_parameters_cc_proto", "//ortools/base:gmock_main", "//ortools/base:parse_test_proto", + "//ortools/base:types", "//ortools/util:stats", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log", @@ -936,6 +1020,8 @@ cc_library( ":sat_inprocessing", ":sat_parameters_cc_proto", ":sat_solver", + ":scheduling_local_search", + ":scheduling_model", ":simplification", ":stat_tables", ":subsolver", @@ -953,6 +1039,7 @@ cc_library( "//ortools/base:threadpool", "//ortools/base:timer", "//ortools/base:types", + "//ortools/base:version", "//ortools/base/macros:buildenv", "//ortools/base/macros:os_support", "//ortools/graph_base:connected_components", @@ -1076,6 +1163,7 @@ cc_library( ":sat_inprocessing", ":sat_parameters_cc_proto", ":sat_solver", + ":scheduling_local_search", ":shaving_solver", ":simplification", ":stat_tables", @@ -1092,6 +1180,7 @@ cc_library( "//ortools/base:threadpool", "//ortools/base:timer", "//ortools/base:types", + "//ortools/base:version", "//ortools/base/macros:buildenv", "//ortools/base/macros:os_support", "//ortools/graph_base:connected_components", @@ -1235,6 +1324,7 @@ cc_library( deps = [ ":cp_model_cc_proto", ":cp_model_utils", + "//ortools/base:types", "@abseil-cpp//absl/random", ], ) @@ -1248,6 +1338,7 @@ cc_library( ":cp_model_utils", ":util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", @@ -1304,6 +1395,7 @@ cc_library( ":util", "//ortools/base:base_export", "//ortools/base:log_severity", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/util:affine_relation", "//ortools/util:bitset", @@ -1379,6 +1471,7 @@ cc_library( ":cp_model_utils", ":presolve_context", "//ortools/base:stl_util", + "//ortools/base:types", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/container:inlined_vector", @@ -1414,6 +1507,7 @@ cc_library( ":solution_crush", "//ortools/base:log_severity", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/algorithm:container", @@ -1458,6 +1552,7 @@ cc_library( ":sat_inprocessing", ":sat_parameters_cc_proto", ":sat_solver", + ":scheduling_model", ":simplification", ":solution_crush", ":util", @@ -1468,6 +1563,7 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/base:timer", + "//ortools/base:types", "//ortools/graph_base:strongly_connected_components", "//ortools/graph_base:topologicalsorter", "//ortools/port:proto_utils", @@ -1564,6 +1660,7 @@ cc_library( ":cp_model_cc_proto", ":cp_model_utils", "//ortools/base:log_severity", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/util:logging", "//ortools/util:sorted_interval_list", @@ -1600,6 +1697,7 @@ cc_library( ":sat_parameters_cc_proto", ":solution_crush", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/util:logging", "//ortools/util:sorted_interval_list", @@ -1636,6 +1734,7 @@ cc_test( "//ortools/util:sorted_interval_list", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/random", "@abseil-cpp//absl/strings", ], ) @@ -1689,6 +1788,7 @@ cc_library( "//ortools/base:log_severity", "//ortools/base:stl_util", "//ortools/base:timer", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/port:sysinfo", "//ortools/util:bitset", @@ -1816,6 +1916,7 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/base:timer", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:integer_pq", "//ortools/util:logging", @@ -1956,6 +2057,7 @@ cc_library( "//ortools/base:stl_util", "//ortools/base:strong_vector", "//ortools/base:timer", + "//ortools/base:types", "//ortools/graph_base:strongly_connected_components", "//ortools/util:logging", "//ortools/util:strong_integers", @@ -2084,6 +2186,7 @@ cc_library( "//ortools/base:mathutil", "//ortools/base:stl_util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:affine_relation", "//ortools/util:logging", "//ortools/util:saturated_arithmetic", @@ -2122,6 +2225,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":sat_base", + "//ortools/base:mathutil", "//ortools/base:strong_vector", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", @@ -2162,6 +2266,7 @@ cc_library( ":sat_solver", "//ortools/base:log_severity", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:rev", "//ortools/util:saturated_arithmetic", @@ -2193,6 +2298,7 @@ cc_library( ":sat_solver", ":synchronization", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:bitset", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", @@ -2216,7 +2322,9 @@ cc_library( ":intervals", ":linear_constraint_manager", ":linear_programming_constraint", + ":linear_propagation", ":model", + ":precedences", ":probing", ":pseudo_costs", ":restart", @@ -2230,12 +2338,16 @@ cc_library( ":synchronization", ":util", "//ortools/base:log_severity", + "//ortools/base:strong_vector", "//ortools/util:strong_integers", "//ortools/util:time_limit", + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", + "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", @@ -2283,6 +2395,7 @@ cc_library( ":util", "//ortools/base:log_severity", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/glop:variables_info", "//ortools/util:logging", "//ortools/util:strong_integers", @@ -2531,6 +2644,7 @@ cc_test( ":sat_solver", "//ortools/base:gmock_main", "//ortools/base:log_severity", + "//ortools/base:types", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", "@abseil-cpp//absl/log", @@ -2559,6 +2673,7 @@ cc_library( ":sat_solver", ":util", "//ortools/base:mathutil", + "//ortools/base:types", "//ortools/util:strong_integers", "//ortools/util:time_limit", "@abseil-cpp//absl/base:core_headers", @@ -2591,6 +2706,7 @@ cc_test( "//ortools/base:log_severity", "//ortools/base:parse_test_proto", "//ortools/base:parse_text_proto", + "//ortools/base:types", "//ortools/port:proto_utils", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", @@ -2617,10 +2733,12 @@ cc_library( ":model", ":precedences", ":sat_base", + ":sat_solver", ":synchronization", ":util", "//ortools/base:log_severity", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:rev", "//ortools/util:strong_integers", @@ -2634,6 +2752,7 @@ cc_library( "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", "@abseil-cpp//absl/numeric:int128", + "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/strings", "@abseil-cpp//absl/types:span", ], @@ -2727,7 +2846,6 @@ cc_library( "//ortools/util:sort", "//ortools/util:strong_integers", "//ortools/util:time_limit", - "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/cleanup", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", @@ -2969,6 +3087,7 @@ cc_library( "//ortools/base:mathutil", "//ortools/base:stl_util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:logging", "//ortools/util:sorted_interval_list", "//ortools/util:strong_integers", @@ -3019,6 +3138,7 @@ cc_library( ":model", ":sat_base", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", "@abseil-cpp//absl/base:core_headers", @@ -3040,6 +3160,7 @@ cc_test( ":sat_base", "//ortools/base:gmock_main", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:strong_integers", "@abseil-cpp//absl/types:span", ], @@ -3051,6 +3172,7 @@ cc_library( hdrs = ["linear_programming_constraint.h"], deps = [ ":cp_model_cc_proto", + ":cp_model_checker", ":cp_model_mapping", ":cuts", ":implied_bounds", @@ -3070,6 +3192,7 @@ cc_library( "//ortools/algorithms:binary_search", "//ortools/base:log_severity", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/glop:parameters_cc_proto", "//ortools/glop:revised_simplex", "//ortools/glop:status", @@ -3134,6 +3257,7 @@ cc_library( ":util", "//ortools/base:hash", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/glop:variables_info", "//ortools/lp_data:lp_types", "//ortools/util:saturated_arithmetic", @@ -3185,6 +3309,7 @@ cc_library( ":synchronization", "//ortools/base:stl_util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/lp_data:lp_types", "//ortools/util:saturated_arithmetic", "//ortools/util:sorted_interval_list", @@ -3264,6 +3389,7 @@ cc_library( "//ortools/base:mathutil", "//ortools/base:stl_util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/graph:max_flow", "//ortools/graph_base:connected_components", "//ortools/graph_base:graph", @@ -3410,6 +3536,7 @@ cc_library( "//ortools/lp_data:lp_types", "//ortools/util:strong_integers", "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/types:span", ], ) @@ -3429,6 +3556,11 @@ cc_library( name = "lp_utils", srcs = ["lp_utils.cc"], hdrs = ["lp_utils.h"], + visibility = [ + "//ortools/bop:__pkg__", + "//third_party/ortools/ortools/linear_solver/proto_solver:__pkg__", + "//visibility:public", + ], deps = [ ":cp_model_cc_proto", ":cp_model_utils", @@ -3436,6 +3568,7 @@ cc_library( ":sat_parameters_cc_proto", "//ortools/base:log_severity", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/glop:lp_solver", "//ortools/glop:parameters_cc_proto", "//ortools/linear_solver:linear_solver_cc_proto", @@ -3554,6 +3687,7 @@ cc_library( ":sat_parameters_cc_proto", "//ortools/base:log_severity", "//ortools/base:mathutil", + "//ortools/base:types", "//ortools/base/macros:os_support", "//ortools/util:bitset", "//ortools/util:random_engine", @@ -3589,6 +3723,7 @@ cc_test( "//ortools/base:gmock_main", "//ortools/base:mathutil", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:random_engine", "//ortools/util:sorted_interval_list", @@ -3787,6 +3922,7 @@ cc_library( ":synchronization", ":util", "//ortools/base:constant_divisor", + "//ortools/base:types", "//ortools/util:bitset", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", @@ -3840,15 +3976,14 @@ cc_library( deps = [ ":diffn_util", ":integer_base", + ":util", "//ortools/base:log_severity", - "//ortools/base:stl_util", "//ortools/graph:minimum_vertex_cover", "//ortools/graph_base:strongly_connected_components", "@abseil-cpp//absl/algorithm:container", "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", - "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", @@ -3905,6 +4040,7 @@ cc_test( ":synchronization", "//ortools/algorithms:binary_search", "//ortools/base:gmock_main", + "//ortools/base:types", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/random", "@abseil-cpp//absl/random:bit_gen_ref", @@ -3975,6 +4111,7 @@ cc_test( ":util", "//ortools/base:gmock", "//ortools/base:gmock_main", + "//ortools/base:types", "//ortools/graph_base:connected_components", "//ortools/graph_base:strongly_connected_components", "//ortools/util:saturated_arithmetic", @@ -4018,6 +4155,7 @@ cc_library( ":synchronization", ":timetable", ":util", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:saturated_arithmetic", "//ortools/util:strong_integers", @@ -4077,6 +4215,7 @@ cc_library( "@abseil-cpp//absl/container:btree", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/types:span", ], @@ -4112,6 +4251,7 @@ cc_library( ":sat_solver", ":util", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:strong_integers", "@abseil-cpp//absl/container:flat_hash_map", "@abseil-cpp//absl/log:check", @@ -4156,6 +4296,7 @@ cc_library( ":util", "//ortools/base:log_severity", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/base/macros:buildenv", "//ortools/graph_base:connected_components", "//ortools/util:adaptative_parameter_value", @@ -4173,7 +4314,6 @@ cc_library( "@abseil-cpp//absl/log", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/log:vlog_is_on", - "@abseil-cpp//absl/meta:type_traits", "@abseil-cpp//absl/random:bit_gen_ref", "@abseil-cpp//absl/random:distributions", "@abseil-cpp//absl/strings", @@ -4230,6 +4370,7 @@ cc_library( ":synchronization", ":util", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/glop:parameters_cc_proto", "//ortools/glop:revised_simplex", "//ortools/glop:status", @@ -4345,6 +4486,7 @@ cc_library( ":sat_base", "//ortools/base:hash", "//ortools/base:strong_vector", + "//ortools/base:types", "//ortools/util:strong_integers", "//ortools/util:time_limit", "@abseil-cpp//absl/container:flat_hash_set", @@ -4484,6 +4626,7 @@ cc_library( ":cp_model_cc_proto", ":cp_model_utils", "//ortools/base:stl_util", + "//ortools/base:types", "//ortools/util:filelineiter", "@abseil-cpp//absl/base:core_headers", "@abseil-cpp//absl/container:flat_hash_map", @@ -4553,6 +4696,7 @@ cc_binary( "//ortools/base", "//ortools/base:file", "//ortools/base:path", + "//ortools/base:types", "//ortools/base:zipfile", "//ortools/util:file_util", "//ortools/util:logging", @@ -4626,6 +4770,7 @@ cc_library( "//ortools/algorithms:sparse_permutation", "//ortools/base:hash", "//ortools/base:log_severity", + "//ortools/base:types", "//ortools/graph_base:graph", "//ortools/util:affine_relation", "//ortools/util:logging", @@ -4745,6 +4890,7 @@ cc_library( hdrs = ["inclusion.h"], deps = [ ":util", + "//ortools/base:types", "//ortools/util:bitset", "//ortools/util:time_limit", "@abseil-cpp//absl/log:check", @@ -4759,6 +4905,7 @@ cc_library( deps = [ ":util", "//ortools/base:mathutil", + "//ortools/base:types", "@abseil-cpp//absl/log:check", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/types:span", @@ -4812,6 +4959,7 @@ cc_library( ":sat_solver", ":synchronization", ":util", + "//ortools/base:types", "//ortools/util:running_stat", "//ortools/util:strong_integers", "//ortools/util:time_limit", @@ -4855,6 +5003,7 @@ cc_test( "//ortools/base:gmock_main", "//ortools/base:log_severity", "//ortools/base:mathutil", + "//ortools/base:types", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random", "@abseil-cpp//absl/strings", diff --git a/ortools/sat/circuit.cc b/ortools/sat/circuit.cc index 6d2c986cddd..a74d2de4aa6 100644 --- a/ortools/sat/circuit.cc +++ b/ortools/sat/circuit.cc @@ -13,13 +13,13 @@ #include "ortools/sat/circuit.h" -#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/types/span.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/sat/all_different.h" @@ -30,6 +30,7 @@ #include "ortools/sat/pb_constraint.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" namespace operations_research { @@ -49,14 +50,14 @@ CircuitPropagator::CircuitPropagator( prev_.resize(num_nodes_, -1); next_literal_.resize(num_nodes_); must_be_in_cycle_.resize(num_nodes_); + + const int num_arcs = tails.size(); absl::flat_hash_map literal_to_watch_index; + literal_to_watch_index.reserve(num_arcs); // Temporary data to fill watch_index_to_arcs_. - const int num_arcs = tails.size(); - std::vector keys; - std::vector values; - keys.reserve(num_arcs); - values.reserve(num_arcs); + CompactVectorVectorBuilder watch_index_to_arcs_builder; + watch_index_to_arcs_builder.ReserveNumItems(num_arcs); graph_.reserve(num_arcs); self_arcs_.resize(num_nodes_, kFalseLiteralIndex); @@ -96,18 +97,14 @@ CircuitPropagator::CircuitPropagator( // Tricky: For self-arc, we watch instead when the arc become false. const Literal watched_literal = tail == head ? literal.Negated() : literal; - const auto& it = literal_to_watch_index.find(watched_literal.Index()); - int watch_index = it != literal_to_watch_index.end() ? it->second : -1; - if (watch_index == -1) { - watch_index = watch_index_to_literal_.size(); - literal_to_watch_index[watched_literal.Index()] = watch_index; + const auto [it, inserted] = literal_to_watch_index.insert( + {watched_literal.Index(), literal_to_watch_index.size()}); + if (inserted) { watch_index_to_literal_.push_back(watched_literal); } - - keys.push_back(watch_index); - values.push_back({tail, head}); + watch_index_to_arcs_builder.Add(it->second, {tail, head}); } - watch_index_to_arcs_.ResetFromFlatMapping(keys, values); + watch_index_to_arcs_.ResetFromBuilder(watch_index_to_arcs_builder); for (int node = 0; node < num_nodes_; ++node) { if (self_arcs_[node] == kFalseLiteralIndex || diff --git a/ortools/sat/clause.cc b/ortools/sat/clause.cc index 177fed7458a..8f20ad136b8 100644 --- a/ortools/sat/clause.cc +++ b/ortools/sat/clause.cc @@ -1164,8 +1164,10 @@ bool BinaryImplicationGraph::FixLiteral(Literal true_literal, if (trail_->Assignment().LiteralIsTrue(true_literal)) return true; if (trail_->Assignment().LiteralIsFalse(true_literal)) { if (lrat_proof_handler_ != nullptr) { - std::vector unsat_proof = {proof.begin(), proof.end()}; + std::vector unsat_proof; + unsat_proof.reserve(proof.size() + 1); unsat_proof.push_back(ClausePtr(true_literal.Negated())); + unsat_proof.insert(unsat_proof.end(), proof.begin(), proof.end()); lrat_proof_handler_->AddInferredClause(ClausePtr::EmptyClausePtr(), unsat_proof); } @@ -2090,6 +2092,19 @@ class LratEquivalenceHelper { std::vector tmp_literals_; }; +void BinaryImplicationGraph::ExportAllEquivalences() { + if (!enable_sharing_) return; + if (add_binary_callback_ == nullptr) return; + for (BooleanVariable var(0); var < representative_of_.size(); ++var) { + const Literal lit(var, true); + const Literal rep = RepresentativeOf(lit); + if (lit != rep) { + add_binary_callback_(lit.Negated(), rep); + add_binary_callback_(lit, rep.Negated()); + } + } +} + bool BinaryImplicationGraph::DetectEquivalences(bool log_info) { // This was already called, and no new constraint where added. Note that new // fixed variable cannot create new equivalence, only new binary clauses do. @@ -2178,6 +2193,22 @@ bool BinaryImplicationGraph::DetectEquivalences(bool log_info) { for (int i = 1; i < component.size(); ++i) { const Literal literal = Literal(LiteralIndex(component[i])); if (!is_redundant_[literal]) { + // This allows to make sure the SharedClauseManager properly detect + // that we have an equivalence. It shouldn't share more than + // O(num_variables) binary clauses per worker. + // + // TODO(user): Alternatively, we could run DetectEquivalences() in the + // SharedClauseManager, that might be more robust to make sure we don't + // miss any. Not clear what is the best approach. + // + // If literal.Negated() == representative the model is unsat, which + // should be detected by the other test below. + if (enable_sharing_ && add_binary_callback_ != nullptr && + literal.Negated() != Literal(representative)) { + add_binary_callback_(literal.Negated(), Literal(representative)); + add_binary_callback_(literal, Literal(representative).Negated()); + } + ++num_new_redundant_literals; is_redundant_.Set(literal); } @@ -2923,7 +2954,7 @@ std::vector BinaryImplicationGraph::ExpandAtMostOneWithWeight( const double lp = use_weight ? expanded_lp_values[Literal(intersection[j]).NegatedIndex()] + - absl::Uniform(*random_, 0.0, 1e-4) + absl::Uniform(random_, 0.0, 1e-4) : can_be_included.size() - intersection[j].value(); if (index == -1 || lp > max_lp) { index = j; @@ -3039,7 +3070,7 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( const double activity = current_value + expanded_lp_values[l.NegatedIndex()]; if (activity <= 1.01) continue; - const double v = activity + absl::Uniform(*random_, 0.0, 1e-4); + const double v = activity + absl::Uniform(random_, 0.0, 1e-4); if (best == kNoLiteralIndex || v > best_value) { best_value = v; best = l.NegatedIndex(); @@ -3076,7 +3107,7 @@ BinaryImplicationGraph::GenerateAtMostOnesWithLargeWeight( const int max_graph_size = 1024; if (fractional_literals.size() > max_graph_size) { std::shuffle(fractional_literals.begin(), fractional_literals.end(), - *random_); + random_); fractional_literals.resize(max_graph_size); } @@ -3398,7 +3429,7 @@ LiteralIndex BinaryImplicationGraph::RandomImpliedLiteral(Literal lhs) { const int size2 = implications_and_amos_[lhs].num_offsets(); if (size1 + size2 == 0) return kNoLiteralIndex; - const int choice = absl::Uniform(*random_, 0, size1 + size2); + const int choice = absl::Uniform(random_, 0, size1 + size2); if (choice < size1) { return implications_and_amos_[lhs].literals()[choice].Index(); } @@ -3406,12 +3437,12 @@ LiteralIndex BinaryImplicationGraph::RandomImpliedLiteral(Literal lhs) { const absl::Span amo = AtMostOne(implications_and_amos_[lhs].offsets()[choice - size1]); CHECK_GE(amo.size(), 2); - const int first_choice = absl::Uniform(*random_, 0, amo.size()); + const int first_choice = absl::Uniform(random_, 0, amo.size()); const Literal lit = amo[first_choice]; if (lit != lhs) return lit.NegatedIndex(); // We are unlucky and just picked the wrong literal: take a different one. - int next_choice = absl::Uniform(*random_, 0, amo.size() - 1); + int next_choice = absl::Uniform(random_, 0, amo.size() - 1); if (next_choice >= first_choice) { next_choice += 1; } @@ -3717,7 +3748,7 @@ BinaryImplicationGraph::BinaryImplicationGraph(Model* model) : SatPropagator("BinaryImplicationGraph"), stats_("BinaryImplicationGraph"), time_limit_(model->GetOrCreate()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), trail_(model->GetOrCreate()), lrat_proof_handler_(model->Mutable()), at_most_one_max_expansion_size_(model->GetOrCreate() diff --git a/ortools/sat/clause.h b/ortools/sat/clause.h index af248d1491c..33ebda032f7 100644 --- a/ortools/sat/clause.h +++ b/ortools/sat/clause.h @@ -579,6 +579,11 @@ class BinaryImplicationGraph : public SatPropagator { void SetAdditionCallback(std::function f) { add_binary_callback_ = f; } + + // Call the "add_binary_callback_" on all var <=> representative. + // See the call-site to see the usage of this. + void ExportAllEquivalences(); + // An at most one constraint of size n is a compact way to encode n * (n - 1) // implications. This must only be called at level zero. // @@ -632,7 +637,7 @@ class BinaryImplicationGraph : public SatPropagator { // TODO(user): Completely get rid of such literal instead? it might not be // reasonable code-wise to remap our literals in all of our constraints // though. - bool DetectEquivalences(bool log_info = false); + ABSL_MUST_USE_RESULT bool DetectEquivalences(bool log_info = false); // Returns true if DetectEquivalences() has been called and no new binary // clauses have been added since then. When this is true then there is no @@ -640,10 +645,12 @@ class BinaryImplicationGraph : public SatPropagator { // form a cycle with their representative). bool IsDag() const { return is_dag_; } - // One must call DetectEquivalences() first, this is CHECKed. // Returns a list so that if x => y, then x is after y. + // + // One must call DetectEquivalences() first, otherwise this can be empty. Note + // that it is usually okay to use an "old" topological order if + // DetectEquivalences() was not called on the last state. const std::vector& ReverseTopologicalOrder() const { - CHECK(is_dag_); return reverse_topological_order_; } @@ -684,8 +691,9 @@ class BinaryImplicationGraph : public SatPropagator { // // Returns false if the model is detected to be UNSAT (this needs to call // DetectEquivalences() if not already done). - bool TransformIntoMaxCliques(std::vector>* at_most_ones, - int64_t max_num_explored_nodes = 1e8); + ABSL_MUST_USE_RESULT bool TransformIntoMaxCliques( + std::vector>* at_most_ones, + int64_t max_num_explored_nodes = 1e8); // This is similar to TransformIntoMaxCliques() but we are just looking into // reducing the number of constraints. If two initial clique A and B can be @@ -976,7 +984,7 @@ class BinaryImplicationGraph : public SatPropagator { mutable StatsGroup stats_; TimeLimit* time_limit_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; Trail* trail_; LratProofHandler* lrat_proof_handler_ = nullptr; LratEquivalenceHelper* lrat_helper_ = nullptr; diff --git a/ortools/sat/clause_test.cc b/ortools/sat/clause_test.cc index 698e99ee1d0..f3e6122ddd3 100644 --- a/ortools/sat/clause_test.cc +++ b/ortools/sat/clause_test.cc @@ -259,7 +259,7 @@ TEST(BinaryImplicationGraphTest, BasicCliqueDetection) { for (const std::vector& at_most_one : at_most_ones) { EXPECT_TRUE(graph->AddAtMostOne(at_most_one)); } - graph->TransformIntoMaxCliques(&at_most_ones); + EXPECT_TRUE(graph->TransformIntoMaxCliques(&at_most_ones)); EXPECT_THAT(at_most_ones[0], LiteralsAre(+1, +2, +3)); EXPECT_TRUE(at_most_ones[1].empty()); EXPECT_TRUE(at_most_ones[2].empty()); @@ -278,7 +278,7 @@ TEST(BinaryImplicationGraphTest, CliqueDetectionAndDuplicates) { } // Here we do not change the clique. - graph->TransformIntoMaxCliques(&at_most_ones); + EXPECT_TRUE(graph->TransformIntoMaxCliques(&at_most_ones)); EXPECT_THAT(at_most_ones, ElementsAre(LiteralsAre(+1, +2), LiteralsAre(+2, +2))); @@ -422,7 +422,9 @@ void TryAmoEquivalences(absl::Span> cliques) { return; } } - graph->DetectEquivalences(); + + // This can be either false or true since this is a fuzzer test. + (void)graph->DetectEquivalences(); } } // namespace diff --git a/ortools/sat/constraint_violation.cc b/ortools/sat/constraint_violation.cc index 2dcf5269388..2c43e3ba4ca 100644 --- a/ortools/sat/constraint_violation.cc +++ b/ortools/sat/constraint_violation.cc @@ -16,9 +16,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -27,12 +27,16 @@ #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/mathutil.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer_base.h" #include "ortools/sat/util.h" #include "ortools/util/dense_set.h" #include "ortools/util/saturated_arithmetic.h" @@ -487,13 +491,46 @@ void LinearIncrementalEvaluator::UpdateScoreOnActivityChange( const int64_t* row_coeffs = &row_coeff_buffer_[data.linear_start]; num_ops_ += 2 * data.num_linear_entries; + const int64_t old_a_minus_new_a = + distances_[c] - domains_[c].Distance(new_activity); + // Computing general Domain distance is slow. // TODO(user): optimize even more for one sided constraints. // Note(user): I tried to factor the two usage of this, but it is slower. const Domain& rhs = domains_[c]; const int64_t rhs_min = rhs.Min(); const int64_t rhs_max = rhs.Max(); - const bool is_simple = rhs.NumIntervals() == 2; + const bool is_simple = rhs.NumIntervals() == 1; + + // Cover the common one-sided case linear <= rhs_max. + // This is how the objective is encoded, and it is important to handle this + // efficiently. + if (is_simple && min_range >= rhs_min) { + for (int k = 0; k < data.num_linear_entries; ++k) { + const int var = row_vars[k]; + const int64_t impact = row_coeffs[k] * jump_deltas[var]; + const int64_t old_b = + std::max(old_activity + impact - rhs_max, 0); + const int64_t new_b = + std::max(new_activity + impact - rhs_max, 0); + + // The old score was: + // weight * static_cast(old_b - old_a); + // the new score is + // weight * static_cast(new_b - new_a); so the diff is: + // weight * static_cast(new_b - new_a - old_b + old_a) + const int64_t diff = old_a_minus_new_a + new_b - old_b; + if (diff == 0) continue; + + // TODO(user): If a variable is at its lower (resp. upper) bound, then + // we know that the score will always move in the same direction, so we + // might skip the last_affected_variables_ update. + jump_scores[var] += weight * static_cast(diff); + last_affected_variables_.Set(var); + } + return; + } + const auto violation = [&rhs, rhs_min, rhs_max, is_simple](int64_t v) { if (v >= rhs_max) { return v - rhs_max; @@ -503,9 +540,6 @@ void LinearIncrementalEvaluator::UpdateScoreOnActivityChange( return is_simple ? int64_t{0} : rhs.Distance(v); } }; - - const int64_t old_a_minus_new_a = - distances_[c] - domains_[c].Distance(new_activity); for (int k = 0; k < data.num_linear_entries; ++k) { const int var = row_vars[k]; const int64_t impact = row_coeffs[k] * jump_deltas[var]; @@ -518,6 +552,7 @@ void LinearIncrementalEvaluator::UpdateScoreOnActivityChange( // weight * static_cast(new_b - new_a); so the diff is: // weight * static_cast(new_b - new_a - old_b + old_a) const int64_t diff = old_a_minus_new_a + new_b - old_b; + if (diff == 0) continue; // TODO(user): If a variable is at its lower (resp. upper) bound, then // we know that the score will always move in the same direction, so we @@ -758,7 +793,7 @@ std::vector LinearIncrementalEvaluator::SlopeBreakpoints( const int64_t slack_min = CapSub(domains_[c].Min(), activity); const int64_t slack_max = CapSub(domains_[c].Max(), activity); - if (slack_min != std::numeric_limits::min()) { + if (slack_min != kint64min) { const int64_t ceil_bp = MathUtil::CeilOfRatio(slack_min, coeff); if (ceil_bp != result.back() && var_domain.Contains(ceil_bp)) { result.push_back(ceil_bp); @@ -768,8 +803,7 @@ std::vector LinearIncrementalEvaluator::SlopeBreakpoints( result.push_back(floor_bp); } } - if (slack_min != slack_max && - slack_max != std::numeric_limits::min()) { + if (slack_min != slack_max && slack_max != kint64min) { const int64_t ceil_bp = MathUtil::CeilOfRatio(slack_max, coeff); if (ceil_bp != result.back() && var_domain.Contains(ceil_bp)) { result.push_back(ceil_bp); @@ -1034,7 +1068,7 @@ int64_t CompiledLinMaxConstraint::ComputeViolationWhenEnforced( absl::Span solution) { const int64_t target_value = ExprValue(ct_proto().lin_max().target(), solution); - int64_t max_of_expressions = std::numeric_limits::min(); + int64_t max_of_expressions = kint64min; for (const LinearExpressionProto& expr : ct_proto().lin_max().exprs()) { const int64_t expr_value = ExprValue(expr, solution); max_of_expressions = std::max(max_of_expressions, expr_value); @@ -1160,6 +1194,7 @@ std::vector CompiledNoOverlapWithTwoIntervals::UsedVariables( const CpModelProto& /*model_proto*/) const { std::vector result; + result.reserve(enforcements_.size() + 4); if (has_enforcement) { for (const int ref : enforcements_) result.push_back(PositiveRef(ref)); } @@ -1174,88 +1209,257 @@ CompiledNoOverlapWithTwoIntervals::UsedVariables( // ----- CompiledNoOverlap2dConstraint ----- -int64_t OverlapOfTwoIntervals(const ConstraintProto& interval1, - const ConstraintProto& interval2, - absl::Span solution) { - for (const int lit : interval1.enforcement_literal()) { - if (!LiteralValue(lit, solution)) return 0; - } - for (const int lit : interval2.enforcement_literal()) { - if (!LiteralValue(lit, solution)) return 0; - } - - const int64_t start1 = ExprValue(interval1.interval().start(), solution); - const int64_t end1 = ExprValue(interval1.interval().end(), solution); - - const int64_t start2 = ExprValue(interval2.interval().start(), solution); - const int64_t end2 = ExprValue(interval2.interval().end(), solution); - - if (start1 >= end2 || start2 >= end1) return 0; // Disjoint. +// This assumes the intervals overlap. +int64_t OverlapOfTwoIntervals(const IntegerValue start1, + const IntegerValue end1, + const IntegerValue start2, + const IntegerValue end2) { + DCHECK((start1 < end2) && (start2 < end1)); // We force a min cost of 1 to cover the case where a interval of size 0 is in // the middle of another interval. return std::max(std::min(std::min(end2 - start2, end1 - start1), std::min(end2 - start1, end1 - start2)), - int64_t{1}); + IntegerValue{1}) + .value(); +} + +int64_t NoOverlapMinRepairDistance(const IntegerValue start1, + const IntegerValue end1, + const IntegerValue start2, + const IntegerValue end2) { + return std::max(std::min(end2 - start1, end1 - start2), IntegerValue{0}) + .value(); +} + +int64_t NoOverlap2dViolation(const Rectangle& r1, const Rectangle& r2) { + // Fast track if one dimension do not overlap. + if (r1.x_min >= r2.x_max || r2.x_min >= r1.x_max) return 0; + if (r1.y_min >= r2.y_max || r2.y_min >= r1.y_max) return 0; + + // TODO(user): Experiment with + // violation += + // std::max(std::min( + // NoOverlapMinRepairDistance(rect_i.x_min, rect_i.x_max, + // rect_j.x_min, rect_j.x_max), + // NoOverlapMinRepairDistance(rect_i.y_min, rect_i.y_max, + // rect_j.y_min, rect_j.y_max)), + // int64_t{0}); + // Currently, the effect is unclear on 2d packing problems. + + return std::max( + std::min( + NoOverlapMinRepairDistance(r1.x_min, r1.x_max, r2.x_min, r2.x_max) * + OverlapOfTwoIntervals(r1.y_min, r1.y_max, r2.y_min, r2.y_max), + NoOverlapMinRepairDistance(r1.y_min, r1.y_max, r2.y_min, r2.y_max) * + OverlapOfTwoIntervals(r1.x_min, r1.x_max, r2.x_min, r2.x_max)), + int64_t{0}); } -int64_t NoOverlapMinRepairDistance(const ConstraintProto& interval1, - const ConstraintProto& interval2, - absl::Span solution) { - for (const int lit : interval1.enforcement_literal()) { - if (!LiteralValue(lit, solution)) return 0; - } - for (const int lit : interval2.enforcement_literal()) { - if (!LiteralValue(lit, solution)) return 0; +CompiledNoOverlap2dConstraint::CompiledNoOverlap2dConstraint( + const ConstraintProto& ct_proto, const CpModelProto& cp_model) + : CompiledConstraintWithProto(ct_proto), cp_model_(cp_model) { + std::vector> var_to_boxes_pairs; + for (int i = 0; i < ct_proto.no_overlap_2d().x_intervals_size(); ++i) { + for (const int interval_idx : {ct_proto.no_overlap_2d().x_intervals(i), + ct_proto.no_overlap_2d().y_intervals(i)}) { + const ConstraintProto& interval = cp_model_.constraints(interval_idx); + for (const int lit : interval.enforcement_literal()) { + var_to_boxes_pairs.push_back({PositiveRef(lit), i}); + } + for (const LinearExpressionProto& expr : + {interval.interval().start(), interval.interval().end(), + interval.interval().size()}) { + for (const int var : expr.vars()) { + var_to_boxes_pairs.push_back({var, i}); + } + } + } } + gtl::STLSortAndRemoveDuplicates(&var_to_boxes_pairs); + var_to_boxes_.ResetFromPairs(var_to_boxes_pairs, cp_model_.variables_size()); + rectangles_.resize(ct_proto.no_overlap_2d().x_intervals_size()); + box_is_active_.assign(ct_proto.no_overlap_2d().x_intervals_size(), false); +} - const int64_t start1 = ExprValue(interval1.interval().start(), solution); - const int64_t end1 = ExprValue(interval1.interval().end(), solution); +Rectangle CompiledNoOverlap2dConstraint::ComputeRectangle( + int box_index, absl::Span solution) const { + const ConstraintProto& x_i = + cp_model_.constraints(ct_proto().no_overlap_2d().x_intervals(box_index)); + const ConstraintProto& y_i = + cp_model_.constraints(ct_proto().no_overlap_2d().y_intervals(box_index)); + return Rectangle{.x_min = ExprValue(x_i.interval().start(), solution), + .x_max = ExprValue(x_i.interval().end(), solution), + .y_min = ExprValue(y_i.interval().start(), solution), + .y_max = ExprValue(y_i.interval().end(), solution)}; +} - const int64_t start2 = ExprValue(interval2.interval().start(), solution); - const int64_t end2 = ExprValue(interval2.interval().end(), solution); +bool CompiledNoOverlap2dConstraint::IsRectangleActive( + int box_index, absl::Span solution) const { + const ConstraintProto& x_i = + cp_model_.constraints(ct_proto().no_overlap_2d().x_intervals(box_index)); + const ConstraintProto& y_i = + cp_model_.constraints(ct_proto().no_overlap_2d().y_intervals(box_index)); - return std::max(std::min(end2 - start1, end1 - start2), int64_t{0}); + for (const auto& enf : + {x_i.enforcement_literal(), y_i.enforcement_literal()}) { + for (const int lit : enf) { + if (!LiteralValue(lit, solution)) return false; + } + } + return true; } -CompiledNoOverlap2dConstraint::CompiledNoOverlap2dConstraint( - const ConstraintProto& ct_proto, const CpModelProto& cp_model) - : CompiledConstraintWithProto(ct_proto), cp_model_(cp_model) {} - int64_t CompiledNoOverlap2dConstraint::ComputeViolationWhenEnforced( absl::Span solution) { DCHECK_GE(ct_proto().no_overlap_2d().x_intervals_size(), 2); const int size = ct_proto().no_overlap_2d().x_intervals_size(); + // Recompute our "cache" first. + std::vector rectangles; + std::vector rectangles_indices; + rectangles.reserve(size); + for (int i = 0; i < size; ++i) { + if (!IsRectangleActive(i, solution)) { + box_is_active_[i] = false; + continue; + } + box_is_active_[i] = true; + rectangles_[i] = ComputeRectangle(i, solution); + + rectangles.push_back(rectangles_[i]); + rectangles_indices.push_back(i); + } + RecomputeActiveBoxes(); + + // Since computing the violation is O(N^2), we will first compute in O(N log + // N) the set of rectangles that overlap with at least one other rectangle. + absl::flat_hash_set overlapping_set; + for (const auto& [i, j] : FindPartialRectangleIntersections(rectangles)) { + overlapping_set.insert(i); + overlapping_set.insert(j); + } + + int new_size = 0; + for (int i = 0; i < rectangles.size(); ++i) { + if (overlapping_set.contains(i)) { + rectangles_indices[new_size] = rectangles_indices[i]; + rectangles[new_size] = rectangles[i]; + new_size++; + } + } + rectangles.resize(new_size); + rectangles_indices.resize(new_size); + int64_t violation = 0; - for (int i = 0; i + 1 < size; ++i) { - const ConstraintProto& x_i = - cp_model_.constraints(ct_proto().no_overlap_2d().x_intervals(i)); - const ConstraintProto& y_i = - cp_model_.constraints(ct_proto().no_overlap_2d().y_intervals(i)); - for (int j = i + 1; j < size; ++j) { - const ConstraintProto& x_j = - cp_model_.constraints(ct_proto().no_overlap_2d().x_intervals(j)); - const ConstraintProto& y_j = - cp_model_.constraints(ct_proto().no_overlap_2d().y_intervals(j)); - - // TODO(user): Experiment with - // violation += - // std::max(std::min(NoOverlapMinRepairDistance(x_i, x_j, solution), - // NoOverlapMinRepairDistance(y_i, y_j, solution)), - // int64_t{0}); - // Currently, the effect is unclear on 2d packing problems. - violation += - std::max(std::min(NoOverlapMinRepairDistance(x_i, x_j, solution) * - OverlapOfTwoIntervals(y_i, y_j, solution), - NoOverlapMinRepairDistance(y_i, y_j, solution) * - OverlapOfTwoIntervals(x_i, x_j, solution)), - int64_t{0}); + for (int i = 0; i < rectangles.size(); ++i) { + const Rectangle& rect_i = rectangles[i]; + for (int j = i + 1; j < rectangles.size(); ++j) { + const Rectangle& rect_j = rectangles[j]; + violation += NoOverlap2dViolation(rect_i, rect_j); } } return violation; } +int64_t CompiledNoOverlap2dConstraint::ViolationDeltaWhenEnforced( + int var, int64_t /*old_value*/, + absl::Span solution_with_new_value) { + // Note that changing a variable can change more than one rectangle. + absl::Span changed_boxes = var_to_boxes_[var]; + int64_t delta = 0; + + // Fast track if a single box changed. + if (changed_boxes.size() == 1) { + const int box_index = changed_boxes[0]; + const Rectangle& old_rect = rectangles_[box_index]; + const bool old_box_is_active = box_is_active_[box_index]; + const bool new_box_is_active = + IsRectangleActive(box_index, solution_with_new_value); + const Rectangle new_rect = + ComputeRectangle(box_index, solution_with_new_value); + + for (const int j : active_boxes_) { + if (j == box_index) continue; + + // The box j is active and not changed. + const Rectangle& rect_j = rectangles_[j]; + + // Remove the old contribution. + if (old_box_is_active) { + delta -= NoOverlap2dViolation(old_rect, rect_j); + } + + // Add new contribution. + if (new_box_is_active) { + delta += NoOverlap2dViolation(new_rect, rect_j); + } + } + return delta; + } + + // Update all the pairs for which at least one of the boxes is + // changed. + const absl::flat_hash_set changed_boxes_set(changed_boxes.begin(), + changed_boxes.end()); + for (const int box_index : changed_boxes) { + const Rectangle& old_rect = rectangles_[box_index]; + const bool old_box_is_active = box_is_active_[box_index]; + + const bool new_box_is_active = + IsRectangleActive(box_index, solution_with_new_value); + const Rectangle new_rect = + ComputeRectangle(box_index, solution_with_new_value); + + Rectangle new_rect_j_storage; + for (int j = 0; j < rectangles_.size(); ++j) { + if (j == box_index) continue; + + // Avoid double-counting two boxes that are both changed. + const bool changed = changed_boxes_set.contains(j); + if (changed && j < box_index) continue; + + const Rectangle* new_rect_j = nullptr; + bool new_rect_j_is_active = false; + const bool old_rect_j_is_active = box_is_active_[j]; + const Rectangle& old_rect_j = rectangles_[j]; + if (changed) { + // The box j is also changed, we need to update it. + new_rect_j_storage = ComputeRectangle(j, solution_with_new_value); + new_rect_j = &new_rect_j_storage; + new_rect_j_is_active = IsRectangleActive(j, solution_with_new_value); + } else { + // The box j is not changed, so new and old are the same. + new_rect_j_is_active = box_is_active_[j]; + new_rect_j = &old_rect_j; + } + // The pair was contributing to the violation before only if both + // boxes were active. Remove the old contribution. + if (old_box_is_active && old_rect_j_is_active) { + delta -= NoOverlap2dViolation(old_rect, old_rect_j); + } + // The new pair contributes to the violation if both boxes are active. + if (new_box_is_active && new_rect_j_is_active) { + delta += NoOverlap2dViolation(new_rect, *new_rect_j); + } + } + } + return delta; +} + +void CompiledNoOverlap2dConstraint::PerformMove( + int var, int64_t old_value, + absl::Span solution_with_new_value) { + violation_ += ViolationDelta(var, old_value, solution_with_new_value); + for (const int box_index : var_to_boxes_[var]) { + box_is_active_[box_index] = + IsRectangleActive(box_index, solution_with_new_value); + rectangles_[box_index] = + ComputeRectangle(box_index, solution_with_new_value); + } + RecomputeActiveBoxes(); +} + template int64_t CompiledNoOverlap2dWithTwoBoxes::ViolationDelta( int /*var*/, int64_t /*old_value*/, absl::Span solution) { @@ -1532,7 +1736,6 @@ void AddCircuitFlowConstraints(LinearIncrementalEvaluator& linear_evaluator, LsEvaluator::LsEvaluator(const CpModelProto& cp_model, const SatParameters& params, TimeLimit* time_limit) : cp_model_(cp_model), params_(params), time_limit_(time_limit) { - var_to_constraints_.resize(cp_model_.variables_size()); var_to_dtime_estimate_.resize(cp_model_.variables_size()); jump_value_optimal_.resize(cp_model_.variables_size(), true); num_violated_constraint_per_var_ignoring_objective_.assign( @@ -1551,7 +1754,6 @@ LsEvaluator::LsEvaluator( absl::Span additional_constraints, TimeLimit* time_limit) : cp_model_(cp_model), params_(params), time_limit_(time_limit) { - var_to_constraints_.resize(cp_model_.variables_size()); var_to_dtime_estimate_.resize(cp_model_.variables_size()); jump_value_optimal_.resize(cp_model_.variables_size(), true); num_violated_constraint_per_var_ignoring_objective_.assign( @@ -1562,28 +1764,27 @@ LsEvaluator::LsEvaluator( } void LsEvaluator::BuildVarConstraintGraph() { - // Clear the var <-> constraint graph. - for (std::vector& ct_indices : var_to_constraints_) ct_indices.clear(); - constraint_to_vars_.resize(constraints_.size()); + constraint_to_vars_.clear(); // Build the var <-> constraint graph. + CompactVectorVectorBuilder var_to_constraints_builder; for (int ct_index = 0; ct_index < constraints_.size(); ++ct_index) { - constraint_to_vars_[ct_index] = - constraints_[ct_index]->UsedVariables(cp_model_); + constraint_to_vars_.Add(constraints_[ct_index]->UsedVariables(cp_model_)); const double dtime = 1e-8 * constraint_to_vars_[ct_index].size(); for (const int var : constraint_to_vars_[ct_index]) { - var_to_constraints_[var].push_back(ct_index); + var_to_constraints_builder.Add(var, ct_index); var_to_dtime_estimate_[var] += dtime; } } + var_to_constraints_.ResetFromBuilder(var_to_constraints_builder, + cp_model_.variables_size()); - // Remove duplicates. - for (std::vector& constraints : var_to_constraints_) { - gtl::STLSortAndRemoveDuplicates(&constraints); + for (int i = 0; i < var_to_constraints_.size(); ++i) { + var_to_constraints_.SortAndRemoveDuplicateValues(i); } - for (std::vector& vars : constraint_to_vars_) { - gtl::STLSortAndRemoveDuplicates(&vars); + for (int i = 0; i < constraint_to_vars_.size(); ++i) { + constraint_to_vars_.SortAndRemoveDuplicateValues(i); } // Scan the model to decide if a variable is linked to a convex evaluation. @@ -2062,6 +2263,14 @@ bool LsEvaluator::IsViolated(int c) const { } } +std::string LsEvaluator::ConstraintDebugString(int c) const { + if (c < linear_evaluator_.num_constraints()) { + return absl::StrCat("Linear constraint ", c); + } else { + return constraints_[c - linear_evaluator_.num_constraints()]->DebugString(); + } +} + double LsEvaluator::WeightedViolation(absl::Span weights) const { DCHECK_EQ(weights.size(), NumEvaluatorConstraints()); double result = linear_evaluator_.WeightedViolation(weights); @@ -2192,7 +2401,7 @@ int64_t CompiledReservoirConstraint::BuildProfileAndReturnViolation( int64_t overload = 0; int64_t current_load = 0; - int64_t previous_time = std::numeric_limits::min(); + int64_t previous_time = kint64min; for (int i = 0; i < profile_.size(); ++i) { // At this point, current_load is the load at previous_time. const int64_t time = profile_[i].time; @@ -2253,7 +2462,7 @@ int64_t CompiledReservoirConstraint::IncrementalViolation( // Similar algo, but we scan the two vectors at once. int64_t overload = 0; int64_t current_load = 0; - int64_t previous_time = std::numeric_limits::min(); + int64_t previous_time = kint64min; // TODO(user): This code is the hotspot for our local search on cumulative. // It can probably be slightly improved. We might also be able to abort early diff --git a/ortools/sat/constraint_violation.h b/ortools/sat/constraint_violation.h index 0dde57b3eb4..a9a9123f9ef 100644 --- a/ortools/sat/constraint_violation.h +++ b/ortools/sat/constraint_violation.h @@ -18,14 +18,18 @@ #include #include #include +#include #include #include #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/stl_util.h" +#include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/diffn_util.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" @@ -267,6 +271,8 @@ class CompiledConstraint { // The cached violation of this constraint. int64_t violation() const { return violation_; } + virtual std::string DebugString() const = 0; + protected: // Computes the violation of a constraint. // @@ -296,6 +302,11 @@ class CompiledConstraintWithProto : public CompiledConstraint { // This just returns the variables used by the stored ct_proto_. std::vector UsedVariables(const CpModelProto& model_proto) const final; + std::string DebugString() const override { + return absl::StrCat("CompiledConstraintWithProto: ", + ProtobufShortDebugString(ct_proto_)); + } + protected: // Computes the violation of a constraint when it is enforced. virtual int64_t ComputeViolationWhenEnforced( @@ -383,6 +394,7 @@ class LsEvaluator { // size as NumEvaluatorConstraints(). int64_t Violation(int c) const; bool IsViolated(int c) const; + std::string ConstraintDebugString(int c) const; double WeightedViolation(absl::Span weights) const; // Computes the delta in weighted violation if solution[var] += delta. @@ -459,9 +471,9 @@ class LsEvaluator { CpModelProto expanded_constraints_; LinearIncrementalEvaluator linear_evaluator_; std::vector> constraints_; - std::vector> var_to_constraints_; + CompactVectorVector var_to_constraints_; std::vector var_to_dtime_estimate_; - std::vector> constraint_to_vars_; + CompactVectorVector constraint_to_vars_; std::vector jump_value_optimal_; TimeLimit* time_limit_; @@ -578,6 +590,12 @@ struct ViewOfAffineLinearExpressionProto { if (coeff != 0) result.push_back(var); } + template + friend void AbslStringify(Sink& sink, + const ViewOfAffineLinearExpressionProto& expr) { + absl::Format(&sink, "%d*I%d+%d", expr.coeff, expr.var, expr.offset); + } + int var = 0; int64_t coeff = 0; int64_t offset = 0; @@ -615,6 +633,12 @@ class CompiledNoOverlapWithTwoIntervals : public CompiledConstraint { ~CompiledNoOverlapWithTwoIntervals() final = default; + std::string DebugString() const final { + return absl::StrCat("CompiledNoOverlapWithTwoIntervals: (", + interval1_.start, " - ", interval1_.end, "), (", + interval2_.start, " - ", interval2_.end, ")"); + } + int64_t ComputeViolation(absl::Span solution) final { // Optimization hack: If we create a ComputeViolationInternal() that we call // from here and in ViolationDelta(), then the later is not inlined below in @@ -645,8 +669,34 @@ class CompiledNoOverlap2dConstraint : public CompiledConstraintWithProto { int64_t ComputeViolationWhenEnforced( absl::Span solution) override; + int64_t ViolationDeltaWhenEnforced( + int var, int64_t old_value, + absl::Span solution_with_new_value) override; + + void PerformMove(int var, int64_t old_value, + absl::Span solution_with_new_value) override; + private: + Rectangle ComputeRectangle(int box_index, + absl::Span solution) const; + bool IsRectangleActive(int box_index, + absl::Span solution) const; + + void RecomputeActiveBoxes() { + active_boxes_.clear(); + active_boxes_.reserve(box_is_active_.size()); + for (int i = 0; i < box_is_active_.size(); ++i) { + if (box_is_active_[i]) { + active_boxes_.push_back(i); + } + } + } + const CpModelProto& cp_model_; + CompactVectorVector var_to_boxes_; + std::vector rectangles_; + std::vector box_is_active_; + std::vector active_boxes_; }; template @@ -691,6 +741,13 @@ class CompiledNoOverlap2dWithTwoBoxes : public CompiledConstraint { ~CompiledNoOverlap2dWithTwoBoxes() final = default; + std::string DebugString() const final { + return absl::StrCat("CompiledNoOverlap2dWithTwoBoxes: (", box1_.x_min, + " - ", box1_.x_max, ")x(", box1_.y_min, " - ", + box1_.y_max, "), (", box2_.x_min, " - ", box2_.x_max, + ")x(", box2_.y_min, " - ", box2_.y_max, ")"); + } + int64_t ComputeViolation(absl::Span solution) final { // Optimization hack: If we create a ComputeViolationInternal() that we call // from here and in ViolationDelta(), then the later is not inlined below in @@ -758,6 +815,10 @@ class CompiledReservoirConstraint : public CompiledConstraint { std::vector UsedVariables(const CpModelProto& model_proto) const final; + std::string DebugString() const final { + return absl::StrCat("CompiledReservoirConstrain, capacity=", capacity_); + } + private: // This works in O(n log n). int64_t BuildProfileAndReturnViolation(absl::Span solution); diff --git a/ortools/sat/container.h b/ortools/sat/container.h index 11b612b3c42..0985b676255 100644 --- a/ortools/sat/container.h +++ b/ortools/sat/container.h @@ -27,6 +27,7 @@ #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/sat_base.h" namespace operations_research { @@ -182,7 +183,7 @@ class LiteralsOrOffsets { void GrowCapacity() { // TODO(user): crash later. // For now, we do if we use more than 2 GB per LiteralOrOffsets. - CHECK_LE(capacity_, std::numeric_limits::max() / 2); + CHECK_LE(capacity_, kuint32max / 2); const uint32_t new_capacity = static_cast(1.3 * static_cast(capacity_)); CHECK_GT(new_capacity, kInlineElements); diff --git a/ortools/sat/cp_model.cc b/ortools/sat/cp_model.cc index ac4a97a2d7a..c837f530784 100644 --- a/ortools/sat/cp_model.cc +++ b/ortools/sat/cp_model.cc @@ -15,7 +15,6 @@ #include #include -#include #include #include #include @@ -26,6 +25,7 @@ #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/sorted_interval_list.h" @@ -972,7 +972,7 @@ Constraint CpModelBuilder::AddGreaterOrEqual(const LinearExpr& left, FillLinearTerms(left, right, proto->mutable_linear()); const int64_t rhs = right.constant() - left.constant(); proto->mutable_linear()->add_domain(rhs); - proto->mutable_linear()->add_domain(std::numeric_limits::max()); + proto->mutable_linear()->add_domain(kint64max); return Constraint(proto); } @@ -981,7 +981,7 @@ Constraint CpModelBuilder::AddLessOrEqual(const LinearExpr& left, ConstraintProto* const proto = cp_model_.add_constraints(); FillLinearTerms(left, right, proto->mutable_linear()); const int64_t rhs = right.constant() - left.constant(); - proto->mutable_linear()->add_domain(std::numeric_limits::min()); + proto->mutable_linear()->add_domain(kint64min); proto->mutable_linear()->add_domain(rhs); return Constraint(proto); } @@ -992,7 +992,7 @@ Constraint CpModelBuilder::AddGreaterThan(const LinearExpr& left, FillLinearTerms(left, right, proto->mutable_linear()); const int64_t rhs = right.constant() - left.constant(); proto->mutable_linear()->add_domain(rhs + 1); - proto->mutable_linear()->add_domain(std::numeric_limits::max()); + proto->mutable_linear()->add_domain(kint64max); return Constraint(proto); } @@ -1001,7 +1001,7 @@ Constraint CpModelBuilder::AddLessThan(const LinearExpr& left, ConstraintProto* const proto = cp_model_.add_constraints(); FillLinearTerms(left, right, proto->mutable_linear()); const int64_t rhs = right.constant() - left.constant(); - proto->mutable_linear()->add_domain(std::numeric_limits::min()); + proto->mutable_linear()->add_domain(kint64min); proto->mutable_linear()->add_domain(rhs - 1); return Constraint(proto); } @@ -1028,10 +1028,10 @@ Constraint CpModelBuilder::AddNotEqual(const LinearExpr& left, ConstraintProto* const proto = cp_model_.add_constraints(); FillLinearTerms(left, right, proto->mutable_linear()); const int64_t rhs = right.constant() - left.constant(); - proto->mutable_linear()->add_domain(std::numeric_limits::min()); + proto->mutable_linear()->add_domain(kint64min); proto->mutable_linear()->add_domain(rhs - 1); proto->mutable_linear()->add_domain(rhs + 1); - proto->mutable_linear()->add_domain(std::numeric_limits::max()); + proto->mutable_linear()->add_domain(kint64max); return Constraint(proto); } diff --git a/ortools/sat/cp_model.h b/ortools/sat/cp_model.h index 9e30ed7d44b..3e258baef36 100644 --- a/ortools/sat/cp_model.h +++ b/ortools/sat/cp_model.h @@ -50,6 +50,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" // IWYU pragma: export. #include "ortools/sat/cp_model_utils.h" @@ -125,7 +126,7 @@ class BoolVar { BoolVar(int index, CpModelBuilder* builder); CpModelBuilder* builder_ = nullptr; - int index_ = std::numeric_limits::min(); + int index_ = kint32min; }; std::ostream& operator<<(std::ostream& os, const BoolVar& var); @@ -198,7 +199,7 @@ class IntVar { IntVar(int index, CpModelBuilder* builder); CpModelBuilder* builder_ = nullptr; - int index_ = std::numeric_limits::min(); + int index_ = kint32min; }; std::ostream& operator<<(std::ostream& os, const IntVar& var); @@ -495,7 +496,7 @@ class IntervalVar { IntervalVar(int index, CpModelBuilder* builder); CpModelBuilder* builder_ = nullptr; - int index_ = std::numeric_limits::min(); + int index_ = kint32min; }; std::ostream& operator<<(std::ostream& os, const IntervalVar& var); diff --git a/ortools/sat/cp_model.proto b/ortools/sat/cp_model.proto index d37c5051dee..b34d68f4f8c 100644 --- a/ortools/sat/cp_model.proto +++ b/ortools/sat/cp_model.proto @@ -273,11 +273,13 @@ message TableConstraintProto { bool negated = 3; } -// The two arrays of variable each represent a function, the second is the -// inverse of the first: f_direct[i] == j <=> f_inverse[j] == i. +// The two arrays of expressions each represent a function, the second is the +// inverse of the first: f_expr_direct[i] == j <=> f_expr_inverse[j] == i. message InverseConstraintProto { - repeated int32 f_direct = 1; - repeated int32 f_inverse = 2; + repeated int32 f_direct = 1; // Legacy field. + repeated int32 f_inverse = 2; // Legacy field. + repeated LinearExpressionProto f_expr_direct = 3; + repeated LinearExpressionProto f_expr_inverse = 4; } // This constraint forces a sequence of expressions to be accepted by an diff --git a/ortools/sat/cp_model_checker.cc b/ortools/sat/cp_model_checker.cc index 3748be2715a..f384598ae87 100644 --- a/ortools/sat/cp_model_checker.cc +++ b/ortools/sat/cp_model_checker.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -31,6 +30,7 @@ #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" @@ -61,13 +61,18 @@ namespace { template bool DomainInProtoIsValid(const ProtoWithDomain& proto) { + if (proto.domain().empty()) return true; if (proto.domain().size() % 2) return false; - std::vector domain; - for (int i = 0; i < proto.domain_size(); i += 2) { + if (proto.domain(0) > proto.domain(1)) return false; + for (int i = 2; i < proto.domain_size(); i += 2) { if (proto.domain(i) > proto.domain(i + 1)) return false; - domain.push_back({proto.domain(i), proto.domain(i + 1)}); + + // Consecutive intervals must be non-adjacent. + // We test it in two steps so the +1 cannot overflow. + if (proto.domain(i - 1) >= proto.domain(i)) return false; + if (proto.domain(i - 1) + 1 >= proto.domain(i)) return false; } - return IntervalsAreSortedAndNonAdjacent(domain); + return true; } bool VariableReferenceIsValid(const CpModelProto& model, int reference) { @@ -112,8 +117,7 @@ std::string ValidateIntegerVariable(const CpModelProto& model, int v) { // the domain must fall in [-kint64max / 2, kint64max / 2]. const int64_t lb = proto.domain(0); const int64_t ub = proto.domain(proto.domain_size() - 1); - if (lb < -std::numeric_limits::max() / 2 || - ub > std::numeric_limits::max() / 2) { + if (lb < -kint64max / 2 || ub > kint64max / 2) { return absl::StrCat( "var #", v, " domain do not fall in [-kint64max / 2, kint64max / 2]. ", ProtobufShortDebugString(proto)); @@ -121,7 +125,7 @@ std::string ValidateIntegerVariable(const CpModelProto& model, int v) { // We do compute ub - lb in some place in the code and do not want to deal // with overflow everywhere. This seems like a reasonable precondition anyway. - if (lb < 0 && lb + std::numeric_limits::max() < ub) { + if (lb < 0 && lb + kint64max < ub) { return absl::StrCat( "var #", v, " has a domain that is too large, i.e. |UB - LB| overflow an int64_t: ", @@ -269,17 +273,17 @@ std::string ValidateLinearExpression(const CpModelProto& model, return absl::StrCat("coeffs_size() != vars_size() in linear expression: ", ProtobufShortDebugString(expr)); } - if (PossibleIntegerOverflow(model, expr.vars(), expr.coeffs(), - expr.offset())) { - return absl::StrCat("Possible overflow in linear expression: ", - ProtobufShortDebugString(expr)); - } for (const int var : expr.vars()) { if (!RefIsPositive(var)) { return absl::StrCat("Invalid negated variable in linear expression: ", ProtobufShortDebugString(expr)); } } + if (PossibleIntegerOverflow(model, expr.vars(), expr.coeffs(), + expr.offset())) { + return absl::StrCat("Possible overflow in linear expression: ", + ProtobufShortDebugString(expr)); + } return ""; } @@ -357,17 +361,15 @@ std::string ValidateIntProdConstraint(const CpModelProto& model, product_domain.ContinuousMultiplicationBy({min_expr, max_expr}); } - if (product_domain.Max() <= -std ::numeric_limits::max() || - product_domain.Min() >= std::numeric_limits::max()) { + if (product_domain.Max() <= -kint64max || product_domain.Min() >= kint64max) { return absl::StrCat("integer overflow in constraint: ", ProtobufShortDebugString(ct)); } // We need to expand the product when its arity is > 2. In that case, we must // be strict with overflows. - if (ct.int_prod().exprs_size() > 2 && - (product_domain.Max() >= std ::numeric_limits::max() || - product_domain.Min() <= -std::numeric_limits::max())) { + if (ct.int_prod().exprs_size() > 2 && (product_domain.Max() >= kint64max || + product_domain.Min() <= -kint64max)) { return absl::StrCat("Potential integer overflow in constraint: ", ProtobufShortDebugString(ct)); } @@ -443,6 +445,15 @@ std::string ValidateElementConstraint(const CpModelProto& model, "since the index will be out of bounds."; } + if (in_legacy_format) { + if (!VariableIndexIsValid(model, element.index()) || + !VariableIndexIsValid(model, element.target())) { + return absl::StrCat( + "Element constraint index and target must valid variables: ", + ProtobufShortDebugString(ct)); + } + } + // We need to be able to manipulate expression like "target - var" without // integer overflow. if (!element.vars().empty()) { @@ -467,15 +478,6 @@ std::string ValidateElementConstraint(const CpModelProto& model, } } - if (in_legacy_format) { - if (!VariableIndexIsValid(model, element.index()) || - !VariableIndexIsValid(model, element.target())) { - return absl::StrCat( - "Element constraint index and target must valid variables: ", - ProtobufShortDebugString(ct)); - } - } - if (in_linear_format) { RETURN_IF_NOT_EMPTY( ValidateLinearExpression(model, element.linear_index())); @@ -502,19 +504,36 @@ std::string ValidateElementConstraint(const CpModelProto& model, std::string ValidateInverseConstraint(const CpModelProto& model, const ConstraintProto& ct) { - if (ct.inverse().f_direct().size() != ct.inverse().f_inverse().size()) { + const InverseConstraintProto& inverse = ct.inverse(); + if ((ct.inverse().f_direct().size() != ct.inverse().f_inverse().size()) || + ct.inverse().f_expr_direct().size() != + ct.inverse().f_expr_inverse().size()) { return absl::StrCat("Non-matching fields size in inverse: ", ProtobufShortDebugString(ct)); } - const InverseConstraintProto& inverse = ct.inverse(); - for (const auto* vars : {&inverse.f_direct(), &inverse.f_inverse()}) { - for (const int var : *vars) { - if (!VariableIndexIsValid(model, var)) { - return absl::StrCat("Invalid variable index in inverse constraint: ", - var); + if (!inverse.f_inverse().empty() && !inverse.f_expr_inverse().empty()) { + return absl::StrCat( + "Inconsistent inverse with both legacy and new format defined: ", + ProtobufShortDebugString(ct)); + } + if (!inverse.f_direct().empty()) { + for (const auto* vars : {&inverse.f_direct(), &inverse.f_inverse()}) { + for (const int var : *vars) { + if (!VariableIndexIsValid(model, var)) { + return absl::StrCat("Invalid variable index in inverse constraint: ", + var); + } } } } + if (!inverse.f_expr_direct().empty()) { + for (const LinearExpressionProto& expr : inverse.f_expr_direct()) { + RETURN_IF_NOT_EMPTY(ValidateLinearExpression(model, expr)); + } + for (const LinearExpressionProto& expr : inverse.f_expr_inverse()) { + RETURN_IF_NOT_EMPTY(ValidateLinearExpression(model, expr)); + } + } return ""; } @@ -581,8 +600,7 @@ std::string ValidateAutomatonConstraint(const CpModelProto& model, const int64_t tail = automaton.transition_tail(i); const int64_t head = automaton.transition_head(i); const int64_t label = automaton.transition_label(i); - if (label <= std::numeric_limits::min() + 1 || - label == std::numeric_limits::max()) { + if (label <= kint64min + 1 || label == kint64max) { return absl::StrCat("labels in the automaton constraint are too big: ", label); } @@ -764,7 +782,7 @@ std::string ValidateCumulativeConstraint(const CpModelProto& model, const int64_t demand_max = MaxOfExpression(model, demand_expr); DCHECK_GE(demand_max, 0); sum_max_demands = CapAdd(sum_max_demands, demand_max); - if (sum_max_demands == std::numeric_limits::max()) { + if (sum_max_demands == kint64max) { return "The sum of max demands do not fit on an int64_t in constraint: " + ProtobufDebugString(ct); } @@ -790,7 +808,7 @@ std::string ValidateNoOverlap2DConstraint(const CpModelProto& model, const int64_t max_size_y = IntervalSizeMax(model, ct.no_overlap_2d().y_intervals(i)); sum_max_areas = CapAdd(sum_max_areas, CapProd(max_size_x, max_size_y)); - if (sum_max_areas == std::numeric_limits::max()) { + if (sum_max_areas == kint64max) { return "Integer overflow when summing all areas in " "constraint: " + ProtobufDebugString(ct); @@ -810,10 +828,8 @@ std::string ValidateReservoirConstraint(const CpModelProto& model, for (const LinearExpressionProto& expr : ct.reservoir().time_exprs()) { RETURN_IF_NOT_EMPTY(ValidateLinearExpression(model, expr)); // We want to be able to safely put time_exprs[i]-time_exprs[j] in a linear. - if (MinOfExpression(model, expr) <= - -std::numeric_limits::max() / 4 || - MaxOfExpression(model, expr) >= - std::numeric_limits::max() / 4) { + if (MinOfExpression(model, expr) <= -kint64max / 4 || + MaxOfExpression(model, expr) >= kint64max / 4) { return absl::StrCat( "Potential integer overflow on time_expr of a reservoir: ", ProtobufShortDebugString(ct)); @@ -841,7 +857,7 @@ std::string ValidateReservoirConstraint(const CpModelProto& model, const int64_t demand_min = MinOfExpression(model, demand); const int64_t demand_max = MaxOfExpression(model, demand); sum_abs = CapAdd(sum_abs, std::max(CapAbs(demand_min), CapAbs(demand_max))); - if (sum_abs == std::numeric_limits::max()) { + if (sum_abs == kint64max) { return "Possible integer overflow in constraint: " + ProtobufDebugString(ct); } @@ -1046,8 +1062,7 @@ std::string ValidateSolutionHint(const CpModelProto& model) { // Reject hints equals to INT_MIN or INT_MAX. for (const int64_t value : hint.values()) { - if (value == std::numeric_limits::min() || - value == std::numeric_limits::max()) { + if (value == kint64min || value == kint64max) { return "The solution hint cannot contains the INT_MIN or INT_MAX values."; } } @@ -1057,43 +1072,51 @@ std::string ValidateSolutionHint(const CpModelProto& model) { } // namespace +bool LinearOverflowChecker::AddTerm(int64_t coeff, int64_t min_domain, + int64_t max_domain) { + if (coeff == kint64min) return false; + const int64_t prod1 = CapProd(min_domain, coeff); + const int64_t prod2 = CapProd(max_domain, coeff); + if (AtMinOrMaxInt64(prod1)) return false; + if (AtMinOrMaxInt64(prod2)) return false; + + // Note that we use min/max with zero to disallow "alternative" terms and + // be sure that we cannot have an overflow if we do the computation in a + // different order. + sum_min = CapAdd(sum_min, std::min(int64_t{0}, std::min(prod1, prod2))); + sum_max = CapAdd(sum_max, std::max(int64_t{0}, std::max(prod1, prod2))); + + // In addition to computing the min/max possible sum, we also often compare + // it with the constraint bounds, so we do not want max - min to overflow. + // We might also create an intermediate variable to represent the sum. + // + // Note that it is important to be symmetric here, as we do not want expr to + // pass but not -expr! + return sum_min >= -kint64max / 2 && sum_max <= kint64max / 2; +} + bool PossibleIntegerOverflow(const CpModelProto& model, absl::Span vars, absl::Span coeffs, int64_t offset, std::pair* implied_domain) { - if (offset == std::numeric_limits::min()) return true; - int64_t sum_min = -std::abs(offset); - int64_t sum_max = +std::abs(offset); + offset = std::abs(offset); + if (offset > kint64max / 2) return true; + + LinearOverflowChecker checher; + checher.sum_min = -std::abs(offset); + checher.sum_max = std::abs(offset); for (int i = 0; i < vars.size(); ++i) { const int ref = vars[i]; - const auto& var_proto = model.variables(PositiveRef(ref)); + CHECK(RefIsPositive(ref)); + const auto& var_proto = model.variables(ref); const int64_t min_domain = var_proto.domain(0); const int64_t max_domain = var_proto.domain(var_proto.domain_size() - 1); - if (coeffs[i] == std::numeric_limits::min()) return true; - const int64_t coeff = RefIsPositive(ref) ? coeffs[i] : -coeffs[i]; - const int64_t prod1 = CapProd(min_domain, coeff); - const int64_t prod2 = CapProd(max_domain, coeff); - // Note that we use min/max with zero to disallow "alternative" terms and - // be sure that we cannot have an overflow if we do the computation in a - // different order. - sum_min = CapAdd(sum_min, std::min(int64_t{0}, std::min(prod1, prod2))); - sum_max = CapAdd(sum_max, std::max(int64_t{0}, std::max(prod1, prod2))); - for (const int64_t v : {prod1, prod2, sum_min, sum_max}) { - if (AtMinOrMaxInt64(v)) return true; - } + if (!checher.AddTerm(coeffs[i], min_domain, max_domain)) return true; } - // In addition to computing the min/max possible sum, we also often compare - // it with the constraint bounds, so we do not want max - min to overflow. - // We might also create an intermediate variable to represent the sum. - // - // Note that it is important to be symmetric here, as we do not want expr to - // pass but not -expr! - if (sum_min < -std::numeric_limits::max() / 2) return true; - if (sum_max > std::numeric_limits::max() / 2) return true; if (implied_domain) { - *implied_domain = {sum_min, sum_max}; + *implied_domain = {checher.sum_min, checher.sum_max}; } return false; } @@ -1113,7 +1136,7 @@ std::string ValidateCpModel(const CpModelProto& model, bool after_presolve) { // We require this precondition so that we can take any linear combination of // variable with coefficient in int64_t and compute the activity on an int128 // with no overflow. This is useful during cut computation. - if (int128_overflow == std::numeric_limits::max()) { + if (int128_overflow == kint64max) { return "The sum of all variable domains do not fit on an int64_t. This is " "needed to prevent overflows."; } @@ -1391,7 +1414,7 @@ class ConstraintChecker { bool LinMaxConstraintIsFeasible(const ConstraintProto& ct) { const int64_t max = LinearExpressionValue(ct.lin_max().target()); - int64_t actual_max = std::numeric_limits::min(); + int64_t actual_max = kint64min; for (int i = 0; i < ct.lin_max().exprs_size(); ++i) { const int64_t expr_value = LinearExpressionValue(ct.lin_max().exprs(i)); actual_max = std::max(actual_max, expr_value); @@ -1461,7 +1484,7 @@ class ConstraintChecker { } } std::sort(start_durations_pairs.begin(), start_durations_pairs.end()); - int64_t previous_end = std::numeric_limits::min(); + int64_t previous_end = kint64min; for (const auto& pair : start_durations_pairs) { if (pair.first < previous_end) return false; previous_end = pair.first + pair.second; @@ -1777,6 +1800,16 @@ class ConstraintChecker { if (fi < 0 || num_variables <= fi) return false; if (i != Value(ct.inverse().f_inverse(fi))) return false; } + const int num_exprs = ct.inverse().f_expr_direct_size(); + if (num_exprs != ct.inverse().f_expr_inverse_size()) return false; + // Check that f_expr_inverse(f_expr_direct(i)) == i; this is sufficient. + for (int i = 0; i < num_exprs; i++) { + const int fi = LinearExpressionValue(ct.inverse().f_expr_direct(i)); + if (fi < 0 || num_exprs <= fi) return false; + if (i != LinearExpressionValue(ct.inverse().f_expr_inverse(fi))) { + return false; + } + } return true; } diff --git a/ortools/sat/cp_model_checker.h b/ortools/sat/cp_model_checker.h index 7ce3f02c28f..03e19418242 100644 --- a/ortools/sat/cp_model_checker.h +++ b/ortools/sat/cp_model_checker.h @@ -49,9 +49,20 @@ std::string ValidateCpModel(const CpModelProto& model, std::string ValidateInputCpModel(const SatParameters& params, const CpModelProto& model); +// This is the logic used by PossibleIntegerOverflow(). +// It is exposed so that we can reuse it in more situations. +struct LinearOverflowChecker { + // Returns false if adding coeff * X for X in [min_domain, max_domain] to a + // linear equation does not satisfy our overflow preconditions. + bool AddTerm(int64_t coeff, int64_t min_domain, int64_t max_domain); + + int64_t sum_min = 0; + int64_t sum_max = 0; +}; + // Check if a given linear expression can create overflow. If it doesn't, -// sets `implied_domain` to the implied domain of the expression. It is exposed -// to test new constraints created during the presolve. +// sets `implied_domain` to a superset of the implied domain of the expression. +// It is exposed to test new constraints created during the presolve. bool PossibleIntegerOverflow( const CpModelProto& model, absl::Span vars, absl::Span coeffs, int64_t offset = 0, diff --git a/ortools/sat/cp_model_checker_test.cc b/ortools/sat/cp_model_checker_test.cc index 8d7d3439e32..6b88a7783d0 100644 --- a/ortools/sat/cp_model_checker_test.cc +++ b/ortools/sat/cp_model_checker_test.cc @@ -14,7 +14,6 @@ #include "ortools/sat/cp_model_checker.h" #include -#include #include #include "absl/log/check.h" @@ -22,6 +21,7 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/parse_test_proto.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" namespace operations_research { @@ -121,6 +121,49 @@ TEST(SolutionIsFeasibleTest, OrToolsIssue3769) { EXPECT_FALSE(SolutionIsFeasible(model, {2, 0})); } +TEST(SolutionIsFeasibleTest, LegacyInverse) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_direct: 0 + f_direct: 1 + f_direct: 2 + f_inverse: 3 + f_inverse: 4 + f_inverse: 5 + } + } + )pb"); + EXPECT_TRUE(SolutionIsFeasible(model, {2, 0, 1, 1, 2, 0})); + EXPECT_FALSE(SolutionIsFeasible(model, {2, 0, 1, 1, 0, 2})); +} + +TEST(SolutionIsFeasibleTest, Inverse) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_expr_direct: { vars: 0 coeffs: -1 offset: 3 } + f_expr_direct: { vars: 1 coeffs: 1 offset: -1 } + f_expr_direct: { vars: 2 coeffs: 3 offset: -2 } + f_expr_inverse: { vars: 0 coeffs: 1 offset: 0 } + f_expr_inverse: { vars: 1 coeffs: 1 offset: 1 } + f_expr_inverse: { vars: 2 coeffs: -2 offset: 2 } + } + } + )pb"); + EXPECT_TRUE(SolutionIsFeasible(model, {1, 1, 1})); + EXPECT_FALSE(SolutionIsFeasible(model, {1, 0, 2})); +} + TEST(SolutionIsFeasibleTest, Reservoir) { const CpModelProto model = ParseTestProto(R"pb( variables { domain: [ 0, 2 ] } @@ -289,8 +332,7 @@ TEST(ValidateCpModelTest, VariableLowerBoundTooLarge2) { } TEST(ValidateCpModelTest, VariableDomainOverflow) { - CHECK_EQ(std::numeric_limits::max() / 2, - int64_t{4611686018427387903}); + CHECK_EQ(kint64max / 2, int64_t{4611686018427387903}); const CpModelProto model_ok = ParseTestProto(R"pb( variables { @@ -311,8 +353,7 @@ TEST(ValidateCpModelTest, VariableDomainOverflow) { )pb"); EXPECT_THAT(ValidateCpModel(model_bad1), HasSubstr("do not fall in")); - CHECK_EQ(std::numeric_limits::min() + 2, - int64_t{-9223372036854775806}); + CHECK_EQ(kint64min + 2, int64_t{-9223372036854775806}); const CpModelProto model_bad2 = ParseTestProto(R"pb( variables { name: 'a' domain: -9223372036854775806 domain: 2 } )pb"); @@ -320,8 +361,7 @@ TEST(ValidateCpModelTest, VariableDomainOverflow) { } TEST(ValidateCpModelTest, ObjectiveOverflow) { - CHECK_EQ(std::numeric_limits::max() / 4, - int64_t{2305843009213693951}); + CHECK_EQ(kint64max / 4, int64_t{2305843009213693951}); const CpModelProto model = ParseTestProto(R"pb( variables { domain: [ -2305843009213693951, 2305843009213693951 ] } variables { domain: [ -2305843009213693951, 2305843009213693951 ] } @@ -603,6 +643,55 @@ TEST(ValidateCpModelTest, NegativeModulo) { HasSubstr("strictly positive modulo argument")); } +TEST(ValidateCpModelTest, InvalidLegacyInverseConstraint) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: 0 domain: 1 } + constraints { inverse { f_direct: [ 0 ] } } + )pb"); + EXPECT_THAT(ValidateCpModel(model), + HasSubstr("Non-matching fields size in inverse")); +} + +TEST(ValidateCpModelTest, InvalidInverseConstraint) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: 0 domain: 1 } + constraints { + inverse { + f_expr_direct: { + vars: [ 0 ] + coeffs: [ 1 ] + } + } + } + )pb"); + EXPECT_THAT(ValidateCpModel(model), + HasSubstr("Non-matching fields size in inverse")); +} + +TEST(ValidateCpModelTest, InconsistenInverseConstraint) { + const CpModelProto model = ParseTestProto(R"pb( + variables { domain: 0 domain: 1 } + constraints { + inverse { + f_direct: [ 0 ] + f_inverse: [ 0 ] + f_expr_direct: { + vars: [ 0 ] + coeffs: [ 1 ] + } + f_expr_inverse: { + vars: [ 0 ] + coeffs: [ 1 ] + } + } + } + )pb"); + EXPECT_THAT( + ValidateCpModel(model), + HasSubstr( + "Inconsistent inverse with both legacy and new format defined")); +} + TEST(ValidateCpModelTest, IncompatibleAutomatonTransitions) { const CpModelProto model = ParseTestProto(R"pb( variables { domain: 0 domain: 1 } diff --git a/ortools/sat/cp_model_copy.cc b/ortools/sat/cp_model_copy.cc index e14ee0fc563..5dd0f369288 100644 --- a/ortools/sat/cp_model_copy.cc +++ b/ortools/sat/cp_model_copy.cc @@ -17,18 +17,18 @@ #include #include #include -#include -#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/numeric/int128.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -38,59 +38,199 @@ #include "google/protobuf/text_format.h" #include "ortools/base/macros/os_support.h" #include "ortools/base/protobuf_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" -#include "ortools/sat/presolve_context.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/synchronization.h" +#include "ortools/util/logging.h" +#include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" namespace operations_research { namespace sat { -namespace { -// This assumes an identity mapping between positive proto refs and Boolean -// variables (this might not be the case if the input proto contains non Boolean -// variables between Boolean ones). -Literal RefToLiteral(int ref) { - return Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref)); +bool ModelCopyHelper::InitializeDomains(std::vector domains, + absl::Span mapping) { + mapping_.assign(mapping.begin(), mapping.end()); + + input_variable_is_fixed_.assign(domains.size(), false); + input_variable_fixed_values_.resize(domains.size()); + for (int var = 0; var < domains.size(); ++var) { + if (domains[var].IsEmpty()) return false; + if (domains[var].IsFixed()) { + input_variable_is_fixed_[var] = true; + input_variable_fixed_values_[var] = domains[var].FixedValue(); + } + } + + if (mapping.empty()) { + mapped_domains_ = std::move(domains); + return true; + } + CHECK_EQ(mapping.size(), domains.size()); + + // Compute the range [0, mapped_size) of the mapped variables + int mapped_size = 0; + for (const int image : mapping) { + if (image == kNoVariableMapping) continue; + mapped_size = std::max(mapped_size, PositiveRef(image) + 1); + } + + // If many variables are mapped to the same one, the domain will be the + // intersection of all initial domains. Note, that we start with empty + // domains to mean unset. This works since after each operation if we are + // empty, we return false right away (UNSAT). + mapped_domains_.resize(mapped_size); + const auto update_domain = [this](int image, const Domain& domain) { + DCHECK(!domain.IsEmpty()); + Domain& mapped_domain = mapped_domains_[image]; + if (mapped_domain.IsEmpty()) { + mapped_domain = domain; + return true; + } else { + mapped_domain = mapped_domain.IntersectionWith(domain); + return !mapped_domain.IsEmpty(); + } + }; + + for (int i = 0; i < mapping.size(); ++i) { + if (mapping[i] == kNoVariableMapping) continue; + const Domain& domain = domains[i]; + const int image = PositiveRef(mapping[i]); + + // Special negative mapping for boolean. + if (!RefIsPositive(mapping[i])) { + DCHECK_GE(domain.Min(), 0); + DCHECK_LE(domain.Max(), 1); + if (domain.IsFixed()) { + const Domain fixed_value = Domain(1 - domain.FixedValue()); + if (!update_domain(image, fixed_value)) return false; + continue; + } + } + + if (!update_domain(image, domain)) return false; + } + + return true; } -int LiteralToRef(Literal lit) { - const int var = lit.Variable().value(); - return lit.IsPositive() ? var : NegatedRef(var); + +bool ModelCopyHelper::InputIsFixed(int ref) const { + const int var = PositiveRef(ref); + if (input_variable_is_fixed_[var]) return true; + + // If the mapped variable is fixed, mark the input as such. + // This make sure we use the latest state as we copy. + // + // TODO(user): This reflect more the state during copy, but it might be + // slower. So maybe we should leave that for presolve? On another hand, more + // than speed, it is the saving in memory that is interesting during the + // first copy. + const int image = mapping_.empty() ? var : mapping_[var]; + DCHECK_NE(image, kNoVariableMapping); + const Domain& mapped_domain = mapped_domains_[PositiveRef(image)]; + if (mapped_domain.IsFixed()) { + input_variable_is_fixed_[var] = true; + input_variable_fixed_values_[var] = RefIsPositive(image) + ? mapped_domain.FixedValue() + : 1 - mapped_domain.FixedValue(); + return true; + } + + return false; } -} // namespace -ModelCopy::ModelCopy(PresolveContext* context, - absl::Span variable_mapping, - absl::Span reverse_mapping) - : context_(context), - variable_mapping_(variable_mapping), - reverse_mapping_(reverse_mapping), - lrat_proof_handler_(context->lrat_proof_handler.get()) {} - -void ModelCopy::ImportVariablesAndMaybeIgnoreNames( - const CpModelProto& in_model) { - if (context_->params().ignore_names()) { - context_->working_model->clear_variables(); - context_->working_model->mutable_variables()->Reserve( - in_model.variables_size()); - for (const IntegerVariableProto& var_proto : in_model.variables()) { - *context_->working_model->add_variables()->mutable_domain() = - var_proto.domain(); +bool ModelCopyHelper::InputFixedLiteralIsTrue(int ref) const { + DCHECK(InputIsFixed(ref)); + const int value = RefIsPositive(ref) ? 1 : 0; + const int var = PositiveRef(ref); + return input_variable_fixed_values_[var] == value; +} + +int64_t ModelCopyHelper::InputFixedValue(int var) const { + DCHECK(InputIsFixed(var)); + DCHECK(RefIsPositive(var)); + return input_variable_fixed_values_[var]; +} + +std::optional ModelCopyHelper::InputFixedValueOrNullopt( + const LinearExpressionProto& expr) const { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars_size(); ++i) { + if (expr.coeffs(i) == 0) continue; + const int var = expr.vars(i); + if (!InputIsFixed(var)) return std::nullopt; + result += expr.coeffs(i) * InputFixedValue(var); + } + return result; +} + +std::optional +ModelCopyHelper::InputFixedLinearArgumentOrNullopt( + const LinearArgumentProto& linear_argument) const { + const auto target = InputFixedValueOrNullopt(linear_argument.target()); + if (target == std::nullopt) return std::nullopt; + FixedLinearArgument result; + result.target = target.value(); + result.exprs.resize(linear_argument.exprs_size()); + for (int i = 0; i < linear_argument.exprs_size(); ++i) { + const auto expr = InputFixedValueOrNullopt(linear_argument.exprs(i)); + if (expr == std::nullopt) return std::nullopt; + result.exprs[i] = expr.value(); + } + return result; +} + +int64_t ModelCopyHelper::MappedMinOf(const LinearExpressionProto& expr) const { + int64_t result = expr.offset(); + for (int i = 0; i < expr.vars_size(); ++i) { + const int64_t coeff = expr.coeffs(i); + if (coeff > 0) { + result += coeff * mapped_domains_[expr.vars(i)].Min(); + } else { + result += coeff * mapped_domains_[expr.vars(i)].Max(); } - } else { - *context_->working_model->mutable_variables() = in_model.variables(); } + return result; } -void ModelCopy::CreateVariablesFromDomains(absl::Span domains) { - for (const Domain& domain : domains) { - FillDomainInProto(domain, context_->working_model->add_variables()); +bool ModelCopyHelper::IntersectMappedDomainWith(int var, const Domain& domain) { + if (mapped_domains_[var].IsIncludedIn(domain)) return true; + UpdateRuleStats("domain: changed"); + + const Domain intersection = mapped_domains_[var].IntersectionWith(domain); + mapped_domains_[var] = intersection; + if (intersection.IsEmpty()) return false; + if (mapping_.empty()) { + solution_crush_.SetOrUpdateVarToDomain(var, intersection); } + return true; +} + +ModelCopy::ModelCopy(CpModelProto* out_proto, Model* model, + absl::Span variable_mapping) + : params_(*model->GetOrCreate()), + logger_(model->GetOrCreate()), + working_model_(out_proto), + variable_mapping_(variable_mapping), + lrat_proof_handler_(model->Mutable()) {} + +bool ModelCopy::ImportVariables(const CpModelProto& in_model) { + std::vector domains; + domains.reserve(in_model.variables().size()); + for (const IntegerVariableProto& var_proto : in_model.variables()) { + domains.push_back(ReadDomainFromProto(var_proto)); + } + return helper_.InitializeDomains(std::move(domains), variable_mapping_); +} + +bool ModelCopy::CreateVariablesFromDomains(absl::Span domains) { + std::vector copy(domains.begin(), domains.end()); + return helper_.InitializeDomains(std::move(copy), variable_mapping_); } // TODO(user): Merge with the phase 1 of the presolve code. @@ -101,9 +241,7 @@ void ModelCopy::CreateVariablesFromDomains(absl::Span domains) { bool ModelCopy::ImportAndSimplifyConstraints( const CpModelProto& in_model, bool first_copy, std::function active_constraints) { - context_->InitializeNewDomains(); - if (context_->ModelIsUnsat()) return false; - const bool ignore_names = context_->params().ignore_names(); + const bool ignore_names = params_.ignore_names(); // If first_copy is true, we reorder the scheduling constraint to be sure they // refer to interval before them. @@ -116,7 +254,7 @@ bool ModelCopy::ImportAndSimplifyConstraints( lrat_proof_handler_->proof_status()->SetMaxOneBasedCnfIndex( in_model.constraints_size()); } - if (first_copy && context_->params().cp_model_pure_sat_presolve()) { + if (first_copy && params_.cp_model_pure_sat_presolve()) { // In this case, just copy all the constraints as is (just with duplicate // literals removed). If LRAT is enabled we cannot simply drop a constraint, // otherwise the occurrence counts in the LRAT checker would be incorrect. @@ -125,11 +263,13 @@ bool ModelCopy::ImportAndSimplifyConstraints( // always use this direct copy when pure SAT presolve is enabled. const std::string error_msg = "cp_model_pure_sat_presolve can only be used with pure SAT problems."; - if (context_->working_model->has_objective()) { + if (working_model_->has_objective()) { LOG(FATAL) << error_msg; } for (int i = 0; i < in_model.variables_size(); ++i) { - if (!context_->CanBeUsedAsLiteral(i)) { + const auto& domain = in_model.variables(i).domain(); + if (domain.size() != 2 || domain[0] < 0 || domain[1] > 1) { + // This is not a Boolean. LOG(FATAL) << error_msg; } } @@ -149,7 +289,7 @@ bool ModelCopy::ImportAndSimplifyConstraints( const auto [it, inserted] = temp_literals_set_.insert(lit); if (inserted) temp_literals_.push_back(lit); } - context_->working_model->add_constraints() + working_model_->add_constraints() ->mutable_bool_or() ->mutable_literals() ->Add(temp_literals_.begin(), temp_literals_.end()); @@ -157,7 +297,7 @@ bool ModelCopy::ImportAndSimplifyConstraints( return true; } - starting_constraint_index_ = context_->working_model->constraints_size(); + starting_constraint_index_ = working_model_->constraints_size(); for (int c = 0; c < in_model.constraints_size(); ++c) { if (active_constraints != nullptr && !active_constraints(c)) { continue; @@ -186,9 +326,13 @@ bool ModelCopy::ImportAndSimplifyConstraints( break; case ConstraintProto::kBoolAnd: if (temp_enforcement_literals_.empty()) { + helper_.UpdateRuleStats("bool_and: non-reified"); for (const int lit : ct.bool_and().literals()) { - context_->UpdateRuleStats("bool_and: non-reified"); - if (!context_->SetLiteralToTrue(lit)) { + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) continue; + return CreateUnsatModel(c, ct); + } + if (!helper_.SetMappedLiteralToTrue(MapLiteral(lit))) { return CreateUnsatModel(c, ct); } } @@ -199,7 +343,8 @@ bool ModelCopy::ImportAndSimplifyConstraints( } break; case ConstraintProto::kLinear: - if (!CopyLinear(ct, /*canonicalize=*/first_copy)) { + if (!CopyLinear(ct, /*canonicalize=*/first_copy || + !variable_mapping_.empty())) { return CreateUnsatModel(c, ct); } break; @@ -305,6 +450,7 @@ bool ModelCopy::ImportAndSimplifyConstraints( if (first_copy) { ExpandNonAffineExpressions(); } + return true; } @@ -315,16 +461,86 @@ bool ModelCopy::ImportObjective(const CpModelProto& in_model) { return true; } +// We have two modes: +// - If variable_mapping is empty, we import the hint and load it +// in the solution_crush to be updated as we copy. +// - Otherwise, we assume solution_crush is not needed, and we copy right +// away the hint in the remapped format. void ModelCopy::ImportSolutionHint(const CpModelProto& in_model) { - if (in_model.has_solution_hint()) { - CopySolutionHint(in_model.solution_hint()); + if (!in_model.has_solution_hint()) return; + const PartialVariableAssignment& hint = in_model.solution_hint(); + PartialVariableAssignment& new_hint = + *working_model_->mutable_solution_hint(); + + // Copy and remap. + // + // Note that we will clamp the hint in FinishCopy() + // since the domain might be more precise then. + int num_clamped = 0; + std::vector hint_added(helper_.MappedDomains().size(), false); + for (int i = 0; i < hint.vars_size(); ++i) { + int mapped_ref = hint.vars(i); + if (!variable_mapping_.empty()) { + mapped_ref = variable_mapping_[mapped_ref]; + if (mapped_ref == kNoVariableMapping) continue; + } + + const int mapped_var = PositiveRef(mapped_ref); + if (hint_added[mapped_var]) continue; + hint_added[mapped_var] = true; + + int64_t hint_value = hint.values(i); + if (!RefIsPositive(mapped_ref)) { + // We alread checked that this must have been a literal. + // Note however that the hint value is from outside the solver, so + // we distinguish true / false with hint_value > 0 and we inverse it. + if (hint_value > 0) { + hint_value = 0; + } else { + hint_value = 1; + } + } + + // We also clamp it. + const int64_t closest_domain_value = + helper_.MappedDomain(mapped_var).ClosestValue(hint_value); + if (closest_domain_value != hint_value) { + ++num_clamped; + hint_value = closest_domain_value; + } + + new_hint.add_vars(mapped_var); + new_hint.add_values(hint_value); + } + if (num_clamped > 0) { + helper_.UpdateRuleStats("hint: moved var hint within its domain", + num_clamped); + } + + if (variable_mapping_.empty()) { + // Load the solution hint in the solution_crush so that it can be updated + // when new variables are added, for instance in + // ExpandNonAffineExpressions(). This is not needed if there is a variable + // mapping, because such mappings are only used on models where non affine + // expressions have already been expanded. + absl::flat_hash_map hint_values; + const int num_vars = helper_.MappedDomains().size(); + for (int i = 0; i < new_hint.vars().size(); ++i) { + hint_values[new_hint.vars(i)] = new_hint.values(i); + } + for (int i = 0; i < num_vars; ++i) { + if (helper_.InputIsFixed(i)) { + hint_values.insert({i, helper_.InputFixedValue(i)}); + } + } + helper_.solution_crush()->LoadSolution(num_vars, hint_values); } } bool ModelCopy::ImportEverythingExceptVariablesConstraintsAndHint( const CpModelProto& in_model, bool copy_symmetry) { if (!in_model.name().empty()) { - context_->working_model->set_name(in_model.name()); + working_model_->set_name(in_model.name()); } if (in_model.has_objective()) { if (!CopyObjective(in_model.objective())) return false; @@ -333,10 +549,9 @@ bool ModelCopy::ImportEverythingExceptVariablesConstraintsAndHint( CopyFloatingPointObjective(in_model.floating_point_objective()); } if (!in_model.search_strategy().empty()) { - *context_->working_model->mutable_search_strategy() = - in_model.search_strategy(); + *working_model_->mutable_search_strategy() = in_model.search_strategy(); for (DecisionStrategyProto& strategy : - *context_->working_model->mutable_search_strategy()) { + *working_model_->mutable_search_strategy()) { google::protobuf::util::RemoveIf(strategy.mutable_exprs(), [](const LinearExpressionProto* expr) { return expr->vars().empty(); @@ -366,76 +581,68 @@ bool ModelCopy::ImportEverythingExceptVariablesConstraintsAndHint( } if (!in_model.assumptions().empty()) { for (const int lit : in_model.assumptions()) { - context_->working_model->add_assumptions(MapLiteral(lit)); + working_model_->add_assumptions(MapLiteralEvenIfFixed(lit)); } } if (in_model.has_symmetry() && copy_symmetry) { CHECK(variable_mapping_.empty()); - *context_->working_model->mutable_symmetry() = in_model.symmetry(); + *working_model_->mutable_symmetry() = in_model.symmetry(); } return true; } -bool ModelCopy::RemapVariablesInProtoAndContext() { - if (variable_mapping_.empty()) { - for (int i = 0; i < context_->working_model->variables_size(); ++i) { - FillDomainInProto(context_->DomainOf(i), - context_->working_model->mutable_variables(i)); - } - return true; +bool ModelCopy::FinishCopy(const CpModelProto& in_model) { + // Copy the mapped domains to the proto. + absl::Span domains = helper_.MappedDomains(); + working_model_->mutable_variables()->Clear(); + working_model_->mutable_variables()->Reserve(domains.size()); + for (const Domain& domain : domains) { + FillDomainInProto(domain, working_model_->add_variables()); } - // Make sure that equivalent variables have the same value if one of them is - // fixed (some variables can be fixed in ImportAndSimplifyConstraints, but the - // equivalent ones are not necessarily fixed too). - auto get_representative = [&](int var) { - const int mapped_ref = variable_mapping_[var]; - if (mapped_ref == kNoVariableMapping) return var; - const int rep = reverse_mapping_[PositiveRef(mapped_ref)]; - return RefIsPositive(mapped_ref) ? rep : NegatedRef(rep); - }; - auto fix_to_value_of = [&](int dst_ref, int src_ref) { - if (!context_->IsFixed(src_ref)) return true; - const int src_var = PositiveRef(src_ref); - const int dst_var = PositiveRef(dst_ref); - const int64_t src_var_value = context_->FixedValue(src_var); - if (RefIsPositive(dst_ref) == RefIsPositive(src_ref)) { - return context_->IntersectDomainWith(dst_var, Domain(src_var_value)); + + // Copy the names if requested. + if (!params_.ignore_names()) { + const int num_original_vars = in_model.variables().size(); + if (variable_mapping_.empty()) { + for (int i = 0; i < num_original_vars; ++i) { + working_model_->mutable_variables(i)->set_name( + in_model.variables(i).name()); + } } else { - // Only Boolean variables can be mapped to a negative ref. - DCHECK(src_var_value == 0 || src_var_value == 1); - return context_->IntersectDomainWith(dst_var, Domain(1 - src_var_value)); + for (int i = 0; i < num_original_vars; ++i) { + const auto& name = in_model.variables(i).name(); + if (name.empty()) continue; + if (variable_mapping_[i] == kNoVariableMapping) continue; + + // We keep the name of the variable with smallest index if non-empty. + const int image = PositiveRef(variable_mapping_[i]); + if (working_model_->variables(image).name().empty()) { + working_model_->mutable_variables(image)->set_name(name); + } + } } - }; - const int num_vars = context_->working_model->variables_size(); - for (int i = 0; i < num_vars; ++i) { - const int rep = get_representative(i); - if (rep != i && !fix_to_value_of(rep, i)) return false; } - for (int i = 0; i < num_vars; ++i) { - const int rep = get_representative(i); - if (rep != i && !fix_to_value_of(i, rep)) return false; - } - context_->working_model->mutable_variables()->Clear(); - context_->working_model->mutable_variables()->Reserve( - reverse_mapping_.size()); - for (int i = 0; i < reverse_mapping_.size(); ++i) { - FillDomainInProto(MappedVarDomain(i), - context_->working_model->add_variables()); - } - context_->ResetAfterCopy(); - context_->InitializeNewDomains(); + + // When there is no mapping, we might have created new variable and updated + // the hint, so write it back. Note that we DCHECK() that we don't create + // new variable when the mapping is non-empty. + if (variable_mapping_.empty()) { + helper_.solution_crush()->StoreSolutionAsHint(*working_model_); + } + + helper_.DisplaySummary(logger_); return true; } bool ModelCopy::PrepareEnforcementCopy(const ConstraintProto& ct) { temp_enforcement_literals_.clear(); for (const int lit : ct.enforcement_literal()) { - if (context_->LiteralIsTrue(lit)) continue; - if (context_->LiteralIsFalse(lit)) { - context_->UpdateRuleStats("enforcement: always false"); + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) continue; + helper_.UpdateRuleStats("enforcement: always false"); return false; } - temp_enforcement_literals_.push_back(MapRef(lit)); + temp_enforcement_literals_.push_back(MapLiteral(lit)); } return true; // Continue processing. } @@ -444,19 +651,18 @@ bool ModelCopy::PrepareEnforcementCopyWithDup(const ConstraintProto& ct) { temp_enforcement_literals_.clear(); temp_enforcement_literals_set_.clear(); for (const int lit : ct.enforcement_literal()) { - if (context_->LiteralIsTrue(lit)) continue; - // Cannot be satisfied. - if (context_->LiteralIsFalse(lit)) { - context_->UpdateRuleStats("enforcement: always false"); + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) continue; + helper_.UpdateRuleStats("enforcement: always false"); return false; } - const int mapped_lit = MapRef(lit); + const int mapped_lit = MapLiteral(lit); if (temp_enforcement_literals_set_.contains(mapped_lit)) { - context_->UpdateRuleStats("enforcement: removed duplicate literal"); + helper_.UpdateRuleStats("enforcement: removed duplicate literal"); continue; } if (temp_enforcement_literals_set_.contains(NegatedRef(mapped_lit))) { - context_->UpdateRuleStats("enforcement: contains x and not(x)"); + helper_.UpdateRuleStats("enforcement: contains x and not(x)"); return false; } @@ -475,15 +681,23 @@ bool ModelCopy::FinishBoolOrCopy() { if (temp_literals_.empty()) return false; if (temp_literals_.size() == 1) { - context_->UpdateRuleStats("bool_or: only one literal"); - const int lit = ReverseMapRef(temp_literals_[0]); - return context_->SetLiteralToTrue(lit); + helper_.UpdateRuleStats("bool_or: only one literal"); + return helper_.SetMappedLiteralToTrue(temp_literals_[0]); } - context_->working_model->add_constraints() - ->mutable_bool_or() - ->mutable_literals() - ->Add(temp_literals_.begin(), temp_literals_.end()); + working_model_->add_constraints()->mutable_bool_or()->mutable_literals()->Add( + temp_literals_.begin(), temp_literals_.end()); + return true; +} + +bool ModelCopy::CopyFalseConstraint() { + if (temp_enforcement_literals_.empty()) return false; + google::protobuf::RepeatedField& literals = + *working_model_->add_constraints()->mutable_bool_or()->mutable_literals(); + literals.Reserve(temp_enforcement_literals_.size()); + for (const int literal : temp_enforcement_literals_) { + literals.Add(NegatedRef(literal)); + } return true; } @@ -493,16 +707,32 @@ bool ModelCopy::CopyBoolOr(const ConstraintProto& ct) { temp_literals_.push_back(NegatedRef(lit)); } for (const int lit : ct.bool_or().literals()) { - if (context_->LiteralIsTrue(lit)) { - return true; - } - if (!context_->LiteralIsFalse(lit)) { - temp_literals_.push_back(MapRef(lit)); + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) return true; + continue; } + temp_literals_.push_back(MapLiteral(lit)); } return FinishBoolOrCopy(); } +namespace { + +// This is only used for LRAT proof support for pure-sat problem. +// +// This assumes an identity mapping between positive proto refs and Boolean +// variables (this might not be the case if the input proto contains non Boolean +// variables between Boolean ones). +Literal RefToLiteral(int ref) { + return Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref)); +} +int LiteralToRef(Literal lit) { + const int var = lit.Variable().value(); + return lit.IsPositive() ? var : NegatedRef(var); +} + +} // namespace + bool ModelCopy::CopyBoolOrWithDupSupport(const ConstraintProto& ct, int one_based_cnf_index) { temp_literals_.clear(); @@ -518,14 +748,17 @@ bool ModelCopy::CopyBoolOrWithDupSupport(const ConstraintProto& ct, temp_literals_.push_back(lit); } for (const int lit : ct.bool_or().literals()) { - if (context_->LiteralIsTrue(lit)) { - context_->UpdateRuleStats("bool_or: always true"); - return true; + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) { + helper_.UpdateRuleStats("bool_or: always true"); + return true; + } + continue; } - if (context_->LiteralIsFalse(lit)) continue; - const int mapped_lit = MapRef(lit); + + const int mapped_lit = MapLiteral(lit); if (temp_literals_set_.contains(NegatedRef(mapped_lit))) { - context_->UpdateRuleStats("bool_or: always true"); + helper_.UpdateRuleStats("bool_or: always true"); return true; } const auto [it, inserted] = temp_literals_set_.insert(mapped_lit); @@ -575,31 +808,24 @@ bool ModelCopy::CopyBoolAnd(const ConstraintProto& ct) { bool at_least_one_false = false; int num_non_fixed_literals = 0; for (const int lit : ct.bool_and().literals()) { - if (context_->LiteralIsFalse(lit)) { + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) continue; at_least_one_false = true; break; } - if (!context_->LiteralIsTrue(lit)) { - num_non_fixed_literals++; - } + num_non_fixed_literals++; } if (at_least_one_false) { - // One enforcement literal must be false. - BoolArgumentProto* bool_or = - context_->working_model->add_constraints()->mutable_bool_or(); - for (const int lit : temp_enforcement_literals_) { - bool_or->add_literals(NegatedRef(lit)); - } - return !bool_or->literals().empty(); + return CopyFalseConstraint(); } else if (num_non_fixed_literals > 0) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); BoolArgumentProto* bool_and = new_ct->mutable_bool_and(); bool_and->mutable_literals()->Reserve(num_non_fixed_literals); for (const int lit : ct.bool_and().literals()) { - if (context_->LiteralIsTrue(lit)) continue; - bool_and->add_literals(MapRef(lit)); + if (helper_.InputIsFixed(lit)) continue; // we tested above for false. + bool_and->add_literals(MapLiteral(lit)); } } return true; @@ -612,26 +838,27 @@ bool ModelCopy::CopyBoolAndWithDupSupport(const ConstraintProto& ct) { temp_literals_.clear(); temp_literals_set_.clear(); for (const int lit : ct.bool_and().literals()) { - if (context_->LiteralIsFalse(lit)) { - context_->UpdateRuleStats("bool and: always false"); + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) continue; + helper_.UpdateRuleStats("bool and: always false"); at_least_one_false = true; break; } - if (context_->LiteralIsTrue(lit)) continue; - const int mapped_lit = MapRef(lit); + + const int mapped_lit = MapLiteral(lit); if (temp_literals_set_.contains(NegatedRef(mapped_lit))) { - context_->UpdateRuleStats("bool and: => x and not(x) "); + helper_.UpdateRuleStats("bool and: => x and not(x) "); at_least_one_false = true; break; } if (temp_enforcement_literals_set_.contains(NegatedRef(mapped_lit))) { - context_->UpdateRuleStats("bool and: not(x) => x"); + helper_.UpdateRuleStats("bool and: not(x) => x"); at_least_one_false = true; break; } if (temp_enforcement_literals_set_.contains(mapped_lit)) { - context_->UpdateRuleStats("bool and: x => x"); + helper_.UpdateRuleStats("bool and: x => x"); continue; } const auto [it, inserted] = temp_literals_set_.insert(mapped_lit); @@ -639,42 +866,41 @@ bool ModelCopy::CopyBoolAndWithDupSupport(const ConstraintProto& ct) { } if (at_least_one_false) { - // One enforcement literal must be false. - BoolArgumentProto* bool_or = - context_->working_model->add_constraints()->mutable_bool_or(); - for (const int lit : temp_enforcement_literals_) { - bool_or->add_literals(NegatedRef(lit)); - } - return !bool_or->literals().empty(); + return CopyFalseConstraint(); } if (temp_literals_.empty()) { - context_->UpdateRuleStats("bool and: empty"); + helper_.UpdateRuleStats("bool and: empty"); return true; } // Copy. - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); new_ct->mutable_bool_and()->mutable_literals()->Add(temp_literals_.begin(), temp_literals_.end()); return true; } -bool ModelCopy::CopyLinearExpression( - const LinearExpressionProto& expr, LinearExpressionProto* dst, - const absl::flat_hash_set* mapped_enforcement_literals) { +template +int64_t ModelCopy::FillNonFixedTermsAndReturnOffset( + const T& proto_with_vars_and_coeffs, int64_t offset) { non_fixed_terms_.clear(); - int64_t offset = expr.offset(); - for (int i = 0; i < expr.vars_size(); ++i) { - int ref = expr.vars(i); - int64_t coeff = expr.coeffs(i); + for (int i = 0; i < proto_with_vars_and_coeffs.vars_size(); ++i) { + int ref = proto_with_vars_and_coeffs.vars(i); + int64_t coeff = proto_with_vars_and_coeffs.coeffs(i); MapTerm(ref, coeff, offset); if (coeff == 0) continue; DCHECK(RefIsPositive(ref)); non_fixed_terms_.push_back({ref, coeff}); } - // TODO(user): We could save work by doing this only for the first copy. + return offset; +} + +bool ModelCopy::CopyLinearExpression( + const LinearExpressionProto& expr, LinearExpressionProto* dst, + const absl::flat_hash_set* mapped_enforcement_literals) { + int64_t offset = FillNonFixedTermsAndReturnOffset(expr, expr.offset()); CanonicalizeLinearExpression(mapped_enforcement_literals, non_fixed_terms_, offset); @@ -687,27 +913,19 @@ bool ModelCopy::CopyLinearExpression( } bool ModelCopy::CopyLinear(const ConstraintProto& ct, bool canonicalize) { - non_fixed_terms_.clear(); - int64_t offset = 0; - for (int i = 0; i < ct.linear().vars_size(); ++i) { - int ref = ct.linear().vars(i); - int64_t coeff = ct.linear().coeffs(i); - MapTerm(ref, coeff, offset); - if (coeff == 0) continue; - DCHECK(RefIsPositive(ref)); - non_fixed_terms_.push_back({ref, coeff}); - } - + int64_t offset = FillNonFixedTermsAndReturnOffset(ct.linear()); if (canonicalize) { + // TODO(user): In practice we always do because this is either user-given + // or we use a variable_mapping_ that might merge variables. CanonicalizeLinearExpression(&temp_enforcement_literals_set_, non_fixed_terms_, offset); } int64_t min_activity = 0; int64_t max_activity = 0; - for (const auto& [ref, coeff] : non_fixed_terms_) { + for (const auto& [var, coeff] : non_fixed_terms_) { DCHECK_NE(coeff, 0); - const Domain& domain = MappedVarDomain(ref); + const Domain& domain = helper_.MappedDomain(var); if (coeff > 0) { min_activity += coeff * domain.Min(); max_activity += coeff * domain.Max(); @@ -722,48 +940,30 @@ bool ModelCopy::CopyLinear(const ConstraintProto& ct, bool canonicalize) { // Trivial constraint? if (implied.IsIncludedIn(new_rhs)) { - context_->UpdateRuleStats("linear: always true"); + helper_.UpdateRuleStats("linear: always true"); return true; } // Constraint is false? const Domain tight_domain = implied.IntersectionWith(new_rhs); if (tight_domain.IsEmpty()) { - if (temp_enforcement_literals_.empty()) return false; - temp_literals_.clear(); - for (const int literal : temp_enforcement_literals_) { - temp_literals_.push_back(NegatedRef(literal)); - } - context_->working_model->add_constraints() - ->mutable_bool_or() - ->mutable_literals() - ->Add(temp_literals_.begin(), temp_literals_.end()); - return true; + return CopyFalseConstraint(); } DCHECK(!non_fixed_terms_.empty()); if (non_fixed_terms_.size() == 1 && temp_enforcement_literals_.empty()) { - context_->UpdateRuleStats("linear1: x in domain"); - auto [single_var, coeff] = non_fixed_terms_[0]; - Domain new_var_domain = new_rhs.InverseMultiplicationBy(coeff); - if (!variable_mapping_.empty()) { - single_var = reverse_mapping_[single_var]; - if (!RefIsPositive(single_var)) { - // A variable can only be reverse mapped to a negative variable - // reference if it is a Boolean variable, in which case the remapped - // domain is 1 minus the original domain. - DCHECK(context_->CanBeUsedAsLiteral(single_var)); - single_var = NegatedRef(single_var); - new_var_domain = new_var_domain.Negation().AdditionWith(Domain(1)); - } - } - return context_->IntersectDomainWith(single_var, new_var_domain); + helper_.UpdateRuleStats("linear1: x in domain"); + const auto [single_var, coeff] = non_fixed_terms_[0]; + const Domain new_var_domain = new_rhs.InverseMultiplicationBy(coeff); + return helper_.IntersectMappedDomainWith(single_var, new_var_domain); } - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); LinearConstraintProto* linear = new_ct->mutable_linear(); + linear->mutable_vars()->Reserve(non_fixed_terms_.size()); + linear->mutable_coeffs()->Reserve(non_fixed_terms_.size()); for (const auto& [var, coeff] : non_fixed_terms_) { linear->add_vars(var); linear->add_coeffs(coeff); @@ -775,7 +975,7 @@ bool ModelCopy::CopyLinear(const ConstraintProto& ct, bool canonicalize) { template void ModelCopy::CanonicalizeLinearExpression( const absl::flat_hash_set* enforcement_literals, - std::vector>& terms, T& offset) const { + std::vector>& terms, T& offset) { // Merge terms with the same variable, remove terms with a zero coefficient // and replace the enforcement literals with their value. int current_var = 0; @@ -786,11 +986,11 @@ void ModelCopy::CanonicalizeLinearExpression( if (enforcement_literals->contains(current_var)) { // If the constraint is enforced, we can assume the variable is at 1. offset += current_coeff; - context_->UpdateRuleStats("linear: enforcement literal in expression"); + helper_.UpdateRuleStats("linear: enforcement literal in expression"); return; } else if (enforcement_literals->contains(NegatedRef(current_var))) { // We can assume the variable is at 0. - context_->UpdateRuleStats("linear: enforcement literal in expression"); + helper_.UpdateRuleStats("linear: enforcement literal in expression"); return; } } @@ -812,13 +1012,33 @@ void ModelCopy::CanonicalizeLinearExpression( maybe_add_current_term(); } if (new_size < terms.size()) { - context_->UpdateRuleStats("linear: fixed or dup variables"); + helper_.UpdateRuleStats("linear: fixed or dup variables"); } terms.resize(new_size); } +void ModelCopy::ConvertSingleVarFormatToExpr(int var, + LinearExpressionProto* expr) { + if (helper_.InputIsFixed(var)) { + expr->set_offset(helper_.InputFixedValue(var)); + } else { + DCHECK(RefIsPositive(var)); + expr->mutable_vars()->Reserve(1); + expr->mutable_coeffs()->Reserve(1); + expr->add_coeffs(1); + if (!variable_mapping_.empty()) { + // We actually only encounter the old format with a mapping in tests. + // And in this case, we just have positive maps. + var = variable_mapping_[var]; + DCHECK_NE(var, kNoVariableMapping); + DCHECK(RefIsPositive(var)); + } + expr->add_vars(var); + } +} + bool ModelCopy::CopyElement(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); if (ct.element().vars().empty() && !ct.element().exprs().empty()) { // New format, just copy and remap variables. FinishEnforcementCopy(new_ct); @@ -832,31 +1052,20 @@ bool ModelCopy::CopyElement(const ConstraintProto& ct) { return true; } - auto fill_expr = [this](int var, LinearExpressionProto* expr) mutable { - if (context_->IsFixed(var)) { - expr->set_offset(context_->FixedValue(var)); - } else { - DCHECK(RefIsPositive(var)); - expr->mutable_vars()->Reserve(1); - expr->mutable_coeffs()->Reserve(1); - expr->add_vars(MapRef(var)); - expr->add_coeffs(1); - } - }; - FinishEnforcementCopy(new_ct); - fill_expr(ct.element().index(), - new_ct->mutable_element()->mutable_linear_index()); - fill_expr(ct.element().target(), - new_ct->mutable_element()->mutable_linear_target()); + ConvertSingleVarFormatToExpr( + ct.element().index(), new_ct->mutable_element()->mutable_linear_index()); + ConvertSingleVarFormatToExpr( + ct.element().target(), + new_ct->mutable_element()->mutable_linear_target()); for (const int var : ct.element().vars()) { - fill_expr(var, new_ct->mutable_element()->add_exprs()); + ConvertSingleVarFormatToExpr(var, new_ct->mutable_element()->add_exprs()); } return true; } bool ModelCopy::CopyAutomaton(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); new_ct->mutable_automaton()->set_starting_state( ct.automaton().starting_state()); *new_ct->mutable_automaton()->mutable_final_states() = @@ -872,27 +1081,15 @@ bool ModelCopy::CopyAutomaton(const ConstraintProto& ct) { } FinishEnforcementCopy(new_ct); - auto fill_expr = [this](int var, LinearExpressionProto* expr) mutable { - if (context_->IsFixed(var)) { - expr->set_offset(context_->FixedValue(var)); - } else { - DCHECK(RefIsPositive(var)); - expr->mutable_vars()->Reserve(1); - expr->mutable_coeffs()->Reserve(1); - expr->add_vars(MapRef(var)); - expr->add_coeffs(1); - } - }; - for (const int var : ct.automaton().vars()) { - fill_expr(var, new_ct->mutable_automaton()->add_exprs()); + ConvertSingleVarFormatToExpr(var, new_ct->mutable_automaton()->add_exprs()); } return true; } bool ModelCopy::CopyTable(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); if (ct.table().vars().empty() && !ct.table().exprs().empty()) { // New format, just copy and remap variables. FinishEnforcementCopy(new_ct); @@ -904,21 +1101,9 @@ bool ModelCopy::CopyTable(const ConstraintProto& ct) { return true; } - auto fill_expr = [this](int var, LinearExpressionProto* expr) mutable { - if (context_->IsFixed(var)) { - expr->set_offset(context_->FixedValue(var)); - } else { - DCHECK(RefIsPositive(var)); - expr->mutable_vars()->Reserve(1); - expr->mutable_coeffs()->Reserve(1); - expr->add_vars(MapRef(var)); - expr->add_coeffs(1); - } - }; - FinishEnforcementCopy(new_ct); for (const int var : ct.table().vars()) { - fill_expr(var, new_ct->mutable_table()->add_exprs()); + ConvertSingleVarFormatToExpr(var, new_ct->mutable_table()->add_exprs()); } *new_ct->mutable_table()->mutable_values() = ct.table().values(); new_ct->mutable_table()->set_negated(ct.table().negated()); @@ -928,7 +1113,7 @@ bool ModelCopy::CopyTable(const ConstraintProto& ct) { bool ModelCopy::CopyAllDiff(const ConstraintProto& ct) { if (ct.all_diff().exprs().size() <= 1) return true; - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); for (const LinearExpressionProto& expr : ct.all_diff().exprs()) { CopyLinearExpression(expr, new_ct->mutable_all_diff()->add_exprs()); } @@ -937,19 +1122,32 @@ bool ModelCopy::CopyAllDiff(const ConstraintProto& ct) { } bool ModelCopy::CopyLinMax(const ConstraintProto& ct) { + // A lin max must have some rhs, or it is not satisfiable. + if (ct.lin_max().exprs().empty()) return CopyFalseConstraint(); + + // Check if everything is fixed. + const auto fixed_linear_argument = + helper_.InputFixedLinearArgumentOrNullopt(ct.lin_max()); + if (fixed_linear_argument.has_value()) { + DCHECK(!fixed_linear_argument->exprs.empty()); + return CopyTrivialConstraint( + *absl::c_max_element(fixed_linear_argument->exprs) == + fixed_linear_argument->target); + } + // We will create it lazily if we end up copying something. ConstraintProto* new_ct = nullptr; // Regroup all constant terms and copy the other. - int64_t max_of_fixed_terms = std::numeric_limits::min(); + int64_t max_of_fixed_terms = kint64min; for (const auto& expr : ct.lin_max().exprs()) { - const std::optional fixed = context_->FixedValueOrNullopt(expr); + const std::optional fixed = helper_.InputFixedValueOrNullopt(expr); if (fixed != std::nullopt) { max_of_fixed_terms = std::max(max_of_fixed_terms, fixed.value()); } else { // copy. if (new_ct == nullptr) { - new_ct = context_->working_model->add_constraints(); + new_ct = working_model_->add_constraints(); } CopyLinearExpression(expr, new_ct->mutable_lin_max()->add_exprs()); } @@ -957,17 +1155,43 @@ bool ModelCopy::CopyLinMax(const ConstraintProto& ct) { // If we have no non-fixed expression, we can just fix the target when it // involves at most one variable. + const LinearExpressionProto& target = ct.lin_max().target(); if (new_ct == nullptr && temp_enforcement_literals_.empty() && - ct.lin_max().target().vars().size() <= 1) { - context_->UpdateRuleStats("lin_max: all exprs fixed during copy"); - return context_->IntersectDomainWith(ct.lin_max().target(), - Domain(max_of_fixed_terms)); + target.vars().size() <= 1) { + helper_.UpdateRuleStats("lin_max: all exprs fixed during copy"); + + // coeff * var + offset == max_of_fixed_terms. + absl::int128 rhs = absl::int128(max_of_fixed_terms) - target.offset(); + if (target.vars().empty()) { + return rhs == 0; // Unsat otherwise. + } + + int var = target.vars(0); + int64_t coeff = target.coeffs(0); + int64_t offset = 0; + MapTerm(var, coeff, offset); + rhs -= offset; + + if (coeff == 0) { + return rhs == 0; // Unsat otherwise. + } + + if (rhs % coeff != 0) return false; + rhs /= coeff; + if (rhs > absl::int128(kint64max) || rhs < absl::int128(kint64min)) { + helper_.UpdateRuleStats( + "lin_max: target must take value outside int64_t range"); + return false; + } + + return helper_.IntersectMappedDomainWith(var, + Domain(static_cast(rhs))); } // Otherwise, add a constant term if needed. - if (max_of_fixed_terms > std::numeric_limits::min()) { + if (max_of_fixed_terms > kint64min) { if (new_ct == nullptr) { - new_ct = context_->working_model->add_constraints(); + new_ct = working_model_->add_constraints(); } new_ct->mutable_lin_max()->add_exprs()->set_offset(max_of_fixed_terms); } @@ -981,27 +1205,27 @@ bool ModelCopy::CopyLinMax(const ConstraintProto& ct) { } bool ModelCopy::CopyCircuit(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); *new_ct->mutable_circuit()->mutable_tails() = ct.circuit().tails(); *new_ct->mutable_circuit()->mutable_heads() = ct.circuit().heads(); new_ct->mutable_circuit()->mutable_literals()->Reserve( ct.circuit().literals_size()); for (const int lit : ct.circuit().literals()) { - new_ct->mutable_circuit()->add_literals(MapLiteral(lit)); + new_ct->mutable_circuit()->add_literals(MapLiteralEvenIfFixed(lit)); } return true; } bool ModelCopy::CopyRoutes(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); *new_ct->mutable_routes()->mutable_tails() = ct.routes().tails(); *new_ct->mutable_routes()->mutable_heads() = ct.routes().heads(); new_ct->mutable_routes()->mutable_literals()->Reserve( ct.routes().literals_size()); for (const int lit : ct.routes().literals()) { - new_ct->mutable_routes()->add_literals(MapLiteral(lit)); + new_ct->mutable_routes()->add_literals(MapLiteralEvenIfFixed(lit)); } new_ct->mutable_routes()->mutable_dimensions()->Reserve( ct.routes().dimensions_size()); @@ -1018,23 +1242,36 @@ bool ModelCopy::CopyRoutes(const ConstraintProto& ct) { } bool ModelCopy::CopyInverse(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); + if (ct.inverse().f_direct().empty() && + !ct.inverse().f_expr_direct().empty()) { + // New format, just copy and remap variables. + FinishEnforcementCopy(new_ct); + for (const LinearExpressionProto& expr : ct.inverse().f_expr_direct()) { + CopyLinearExpression(expr, + new_ct->mutable_inverse()->add_f_expr_direct()); + } + for (const LinearExpressionProto& expr : ct.inverse().f_expr_inverse()) { + CopyLinearExpression(expr, + new_ct->mutable_inverse()->add_f_expr_inverse()); + } + return true; + } + FinishEnforcementCopy(new_ct); - new_ct->mutable_inverse()->mutable_f_direct()->Reserve( - ct.inverse().f_direct_size()); for (const int f : ct.inverse().f_direct()) { - new_ct->mutable_inverse()->add_f_direct(MapRef(f)); + ConvertSingleVarFormatToExpr( + f, new_ct->mutable_inverse()->add_f_expr_direct()); } - new_ct->mutable_inverse()->mutable_f_inverse()->Reserve( - ct.inverse().f_inverse_size()); for (const int f : ct.inverse().f_inverse()) { - new_ct->mutable_inverse()->add_f_inverse(MapRef(f)); + ConvertSingleVarFormatToExpr( + f, new_ct->mutable_inverse()->add_f_expr_inverse()); } return true; } bool ModelCopy::CopyReservoir(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); new_ct->mutable_reservoir()->set_min_level(ct.reservoir().min_level()); new_ct->mutable_reservoir()->set_max_level(ct.reservoir().max_level()); @@ -1052,7 +1289,8 @@ bool ModelCopy::CopyReservoir(const ConstraintProto& ct) { new_ct->mutable_reservoir()->mutable_active_literals()->Reserve( ct.reservoir().active_literals_size()); for (const int lit : ct.reservoir().active_literals()) { - new_ct->mutable_reservoir()->add_active_literals(MapLiteral(lit)); + new_ct->mutable_reservoir()->add_active_literals( + MapLiteralEvenIfFixed(lit)); } return true; } @@ -1068,24 +1306,23 @@ bool ModelCopy::CopyAtMostOne(const ConstraintProto& ct) { int num_true = 0; temp_literals_.clear(); for (const int lit : ct.at_most_one().literals()) { - if (context_->LiteralIsFalse(lit)) continue; - if (context_->LiteralIsTrue(lit)) { - num_true++; + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) num_true++; continue; } - temp_literals_.push_back(MapRef(lit)); + temp_literals_.push_back(MapLiteral(lit)); } if (num_true > 1) return false; if (num_true == 1) { for (int lit : temp_literals_) { - if (!context_->SetLiteralToFalse(ReverseMapRef(lit))) return false; + if (!helper_.SetMappedLiteralToFalse(lit)) return false; } return true; } if (temp_literals_.size() <= 1) return true; - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); new_ct->mutable_at_most_one()->mutable_literals()->Add(temp_literals_.begin(), temp_literals_.end()); return true; @@ -1102,45 +1339,43 @@ bool ModelCopy::CopyExactlyOne(const ConstraintProto& ct) { int num_true = 0; temp_literals_.clear(); for (const int lit : ct.exactly_one().literals()) { - if (context_->LiteralIsFalse(lit)) continue; - if (context_->LiteralIsTrue(lit)) { - num_true++; + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) num_true++; continue; } - temp_literals_.push_back(MapRef(lit)); + temp_literals_.push_back(MapLiteral(lit)); } if (num_true > 1) return false; if (num_true == 1) { for (int lit : temp_literals_) { - if (!context_->SetLiteralToFalse(ReverseMapRef(lit))) return false; + if (!helper_.SetMappedLiteralToFalse(lit)) return false; } return true; } if (temp_literals_.empty()) return false; if (temp_literals_.size() == 1) { - return context_->SetLiteralToTrue(ReverseMapRef(temp_literals_[0])); + return helper_.SetMappedLiteralToTrue(temp_literals_[0]); } - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); new_ct->mutable_exactly_one()->mutable_literals()->Add(temp_literals_.begin(), temp_literals_.end()); return true; } bool ModelCopy::CopyBoolXor(const ConstraintProto& ct) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); int num_true = 0; temp_literals_.clear(); for (const int lit : ct.bool_xor().literals()) { - if (context_->LiteralIsFalse(lit)) continue; - if (context_->LiteralIsTrue(lit)) { - num_true++; + if (helper_.InputIsFixed(lit)) { + if (helper_.InputFixedLiteralIsTrue(lit)) num_true++; continue; } - temp_literals_.push_back(MapRef(lit)); + temp_literals_.push_back(MapLiteral(lit)); } if (num_true % 2 == 1) { temp_literals_.push_back(GetTrueMappedLiteral()); @@ -1155,8 +1390,8 @@ bool ModelCopy::CopyInterval(const ConstraintProto& ct, int c, CHECK_EQ(starting_constraint_index_, 0) << "Adding new interval constraints to partially filled model is not " "supported."; - interval_mapping_[c] = context_->working_model->constraints_size(); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + interval_mapping_[c] = working_model_->constraints_size(); + ConstraintProto* new_ct = working_model_->add_constraints(); if (!ignore_names) { new_ct->set_name(ct.name()); } @@ -1179,7 +1414,18 @@ bool ModelCopy::CopyInterval(const ConstraintProto& ct, int c, } bool ModelCopy::CopyIntProd(const ConstraintProto& ct, bool ignore_names) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + // Check if everything is fixed. + const auto fixed_linear_argument = + helper_.InputFixedLinearArgumentOrNullopt(ct.int_prod()); + if (fixed_linear_argument.has_value()) { + int64_t prod = 1; + for (const int64_t expr : fixed_linear_argument->exprs) { + prod = CapProd(prod, expr); + } + return CopyTrivialConstraint(prod == fixed_linear_argument->target); + } + + ConstraintProto* new_ct = working_model_->add_constraints(); if (!ignore_names) { new_ct->set_name(ct.name()); } @@ -1193,7 +1439,16 @@ bool ModelCopy::CopyIntProd(const ConstraintProto& ct, bool ignore_names) { } bool ModelCopy::CopyIntDiv(const ConstraintProto& ct, bool ignore_names) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + // Check if everything is fixed. + const auto fixed_linear_argument = + helper_.InputFixedLinearArgumentOrNullopt(ct.int_div()); + if (fixed_linear_argument.has_value()) { + return CopyTrivialConstraint(fixed_linear_argument->exprs[0] / + fixed_linear_argument->exprs[1] == + fixed_linear_argument->target); + } + + ConstraintProto* new_ct = working_model_->add_constraints(); if (!ignore_names) { new_ct->set_name(ct.name()); } @@ -1207,7 +1462,16 @@ bool ModelCopy::CopyIntDiv(const ConstraintProto& ct, bool ignore_names) { } bool ModelCopy::CopyIntMod(const ConstraintProto& ct, bool ignore_names) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + // Check if everything is fixed. + const auto fixed_linear_argument = + helper_.InputFixedLinearArgumentOrNullopt(ct.int_mod()); + if (fixed_linear_argument.has_value()) { + return CopyTrivialConstraint(fixed_linear_argument->exprs[0] % + fixed_linear_argument->exprs[1] == + fixed_linear_argument->target); + } + + ConstraintProto* new_ct = working_model_->add_constraints(); if (!ignore_names) { new_ct->set_name(ct.name()); } @@ -1247,15 +1511,19 @@ bool ModelCopy::AddLinearConstraintForInterval(const ConstraintProto& ct) { } // An enforced interval must have its size non-negative. + // + // Tricky: This is only called during first copy, so there is no mapping and + // we can use the MappedMinOf(). Alternatively we could look at the domain of + // the input cp_model_proto directly. + CHECK(variable_mapping_.empty()); const LinearExpressionProto& size_expr = itv.size(); - if (context_->MinOf(size_expr) < 0) { + if (helper_.MappedMinOf(size_expr) < 0) { tmp_constraint_.Clear(); *tmp_constraint_.mutable_enforcement_literal() = ct.enforcement_literal(); *tmp_constraint_.mutable_linear()->mutable_vars() = size_expr.vars(); *tmp_constraint_.mutable_linear()->mutable_coeffs() = size_expr.coeffs(); tmp_constraint_.mutable_linear()->add_domain(-size_expr.offset()); - tmp_constraint_.mutable_linear()->add_domain( - std::numeric_limits::max()); + tmp_constraint_.mutable_linear()->add_domain(kint64max); if (!CopyLinear(tmp_constraint_, true)) return false; } @@ -1272,17 +1540,20 @@ int ModelCopy::GetOrCreateVariableForConjunction(std::vector* literals) { std::sort(literals->begin(), literals->end()); auto it = boolean_product_encoding_.find(*literals); if (it != boolean_product_encoding_.end()) return it->second; - const int new_var = context_->NewBoolVarWithConjunction(*literals); + + const int new_var = helper_.NewIntVar(Domain(0, 1)); + helper_.solution_crush()->SetVarToConjunction(new_var, *literals); + boolean_product_encoding_[*literals] = new_var; // Add the constraint 'literals => new_var' - auto* ct1 = context_->working_model->add_constraints(); + auto* ct1 = working_model_->add_constraints(); ct1->mutable_bool_or()->mutable_literals()->Reserve(literals->size() + 1); for (const int literal : *literals) { ct1->mutable_bool_or()->add_literals(NegatedRef(literal)); } ct1->mutable_bool_or()->add_literals(new_var); // Add the constraint 'new_var => literals' - auto* ct2 = context_->working_model->add_constraints(); + auto* ct2 = working_model_->add_constraints(); ct2->add_enforcement_literal(new_var); *ct2->mutable_bool_and()->mutable_literals() = {literals->begin(), literals->end()}; @@ -1291,7 +1562,7 @@ int ModelCopy::GetOrCreateVariableForConjunction(std::vector* literals) { void ModelCopy::CopyAndMapNoOverlap(const ConstraintProto& ct) { // Note that we don't copy names here. - auto* new_ct = context_->working_model->add_constraints(); + auto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); NoOverlapConstraintProto* no_overlap = new_ct->mutable_no_overlap(); no_overlap->mutable_intervals()->Reserve(ct.no_overlap().intervals().size()); @@ -1305,7 +1576,7 @@ void ModelCopy::CopyAndMapNoOverlap(const ConstraintProto& ct) { void ModelCopy::CopyAndMapNoOverlap2D(const ConstraintProto& ct) { // Note that we don't copy names here. - auto* new_ct = context_->working_model->add_constraints(); + auto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); NoOverlap2DConstraintProto* no_overlap_2d = new_ct->mutable_no_overlap_2d(); const int num_intervals = ct.no_overlap_2d().x_intervals().size(); @@ -1322,24 +1593,14 @@ void ModelCopy::CopyAndMapNoOverlap2D(const ConstraintProto& ct) { } bool ModelCopy::CopyAndMapCumulative(const ConstraintProto& ct) { - if (ct.cumulative().intervals().empty() && - context_->IsFixed(ct.cumulative().capacity())) { + const std::optional fixed_capa = + helper_.InputFixedValueOrNullopt(ct.cumulative().capacity()); + if (ct.cumulative().intervals().empty() && fixed_capa != std::nullopt) { // Trivial constraint, either obviously SAT or UNSAT if enforced. - const int64_t capacity = context_->FixedValue(ct.cumulative().capacity()); - if (temp_enforcement_literals_.empty()) { - return capacity >= 0; - } - if (capacity < 0) { - // At least one enforcement literal must be false. - auto* new_ct = context_->working_model->add_constraints(); - for (const int literal : temp_enforcement_literals_) { - new_ct->mutable_bool_or()->add_literals(NegatedRef(literal)); - } - } - return true; + return CopyTrivialConstraint(fixed_capa.value() >= 0); } // Note that we don't copy names here. - auto* new_ct = context_->working_model->add_constraints(); + auto* new_ct = working_model_->add_constraints(); FinishEnforcementCopy(new_ct); CumulativeConstraintProto* cumulative = new_ct->mutable_cumulative(); CopyLinearExpression(ct.cumulative().capacity(), @@ -1361,15 +1622,20 @@ bool ModelCopy::CopyAndMapCumulative(const ConstraintProto& ct) { } bool ModelCopy::CopyObjective(const CpObjectiveProto& objective) { - non_fixed_terms_.clear(); - int64_t offset = 0; + int64_t offset = FillNonFixedTermsAndReturnOffset(objective); + CanonicalizeLinearExpression(/*enforcement_literals=*/nullptr, + non_fixed_terms_, offset); + + CpObjectiveProto& new_objective = *working_model_->mutable_objective(); + new_objective = objective; + new_objective.clear_vars(); + new_objective.clear_coeffs(); int64_t min_activity = 0; int64_t max_activity = 0; - for (int i = 0; i < objective.vars_size(); ++i) { - int ref = objective.vars(i); - int64_t coeff = objective.coeffs(i); - DCHECK(RefIsPositive(ref)); - const Domain& domain = context_->DomainOf(ref); + for (const auto [var, coeff] : non_fixed_terms_) { + new_objective.add_vars(var); + new_objective.add_coeffs(coeff); + const Domain& domain = helper_.MappedDomain(var); if (coeff > 0) { min_activity += coeff * domain.Min(); max_activity += coeff * domain.Max(); @@ -1377,30 +1643,14 @@ bool ModelCopy::CopyObjective(const CpObjectiveProto& objective) { min_activity += coeff * domain.Max(); max_activity += coeff * domain.Min(); } - MapTerm(ref, coeff, offset); - if (coeff == 0) continue; - DCHECK(RefIsPositive(ref)); - non_fixed_terms_.push_back({ref, coeff}); - } - CanonicalizeLinearExpression(/*enforcement_literals=*/nullptr, - non_fixed_terms_, offset); - - CpObjectiveProto& new_objective = - *context_->working_model->mutable_objective(); - new_objective = objective; - new_objective.clear_vars(); - new_objective.clear_coeffs(); - for (const auto [ref, coeff] : non_fixed_terms_) { - new_objective.add_vars(ref); - new_objective.add_coeffs(coeff); } new_objective.set_offset(new_objective.offset() + static_cast(offset)); if (objective.domain_size() > 0) { - Domain domain = ReadDomainFromProto(objective); + Domain domain = + ReadDomainFromProto(objective).AdditionWith(Domain(-offset)); domain = domain.IntersectionWith(Domain(min_activity, max_activity)); if (domain.IsEmpty()) return false; - domain = domain.AdditionWith(Domain(-offset)); FillDomainInProto(domain, &new_objective); } new_objective.set_integer_before_offset( @@ -1424,7 +1674,7 @@ void ModelCopy::CopyFloatingPointObjective( non_fixed_terms, offset); FloatObjectiveProto& new_objective = - *context_->working_model->mutable_floating_point_objective(); + *working_model_->mutable_floating_point_objective(); new_objective = objective; new_objective.clear_vars(); new_objective.clear_coeffs(); @@ -1435,56 +1685,9 @@ void ModelCopy::CopyFloatingPointObjective( new_objective.set_offset(new_objective.offset() + offset); } -void ModelCopy::CopySolutionHint(const PartialVariableAssignment& hint) { - PartialVariableAssignment& new_hint = - *context_->working_model->mutable_solution_hint(); - if (variable_mapping_.empty()) { - new_hint = hint; - } else { - std::vector hint_added(reverse_mapping_.size(), false); - for (int i = 0; i < hint.vars_size(); ++i) { - const int mapped_ref = variable_mapping_[hint.vars(i)]; - if (mapped_ref == kNoVariableMapping) continue; - const int mapped_var = PositiveRef(mapped_ref); - if (hint_added[mapped_var]) continue; - hint_added[mapped_var] = true; - new_hint.add_vars(mapped_var); - const int64_t hint_value = hint.values(i); - if (RefIsPositive(mapped_ref)) { - new_hint.add_values(hint.values(i)); - } else { - DCHECK(context_->CanBeUsedAsLiteral(hint.vars(i))); - DCHECK(hint_value == 0 || hint_value == 1); - new_hint.add_values(1 - hint_value); - } - } - } - - // We make sure the hint is within the variables domain. - // - // This allows to avoid overflow because we know evaluating constraints on - // the variables domains should be safe thanks to the initial validation. - const int num_terms = new_hint.vars().size(); - for (int i = 0; i < num_terms; ++i) { - const int var = new_hint.vars(i); - const int64_t value = new_hint.values(i); - const Domain& domain = MappedVarDomain(var); - if (domain.IsEmpty()) continue; // UNSAT. - const int64_t closest_domain_value = domain.ClosestValue(value); - if (closest_domain_value != value) { - context_->UpdateRuleStats("hint: moved var hint within its domain"); - new_hint.set_values(i, closest_domain_value); - } - } -} - bool ModelCopy::CreateUnsatModel(int c, const ConstraintProto& ct) { - context_->working_model->mutable_constraints()->Clear(); - context_->working_model->add_constraints()->mutable_bool_or(); - - // If the model was already marked as unsat, we keep the old message and just - // return. TODO(user): Append messages instead? - if (context_->ModelIsUnsat()) return false; + working_model_->mutable_constraints()->Clear(); + working_model_->add_constraints()->mutable_bool_or(); std::string proto_string; #if defined(ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR) @@ -1498,24 +1701,23 @@ bool ModelCopy::CreateUnsatModel(int c, const ConstraintProto& ct) { std::string message = absl::StrCat( "proven during initial copy of constraint #", c, ":\n", proto_string); std::vector vars = UsedVariables(ct); - if (vars.size() < 10) { + if (vars.size() < 10 && variable_mapping_.empty()) { absl::StrAppend(&message, "With current variable domains:\n"); for (const int var : vars) { absl::StrAppend(&message, "var:", var, - " domain:", context_->DomainOf(var).ToString(), "\n"); + " domain:", helper_.MappedDomain(var).ToString(), "\n"); } } - return context_->NotifyThatModelIsUnsat(message); + + helper_.DisplaySummary(logger_); + SOLVER_LOG(logger_, "INFEASIBLE: '", message, "'"); + return false; } void ModelCopy::ExpandNonAffineExpressions() { - // Make sure all domains are initialized (they are used in - // MaybeExpandNonAffineExpression()). - context_->InitializeNewDomains(); - non_affine_expression_to_new_var_.clear(); - for (int c = 0; c < context_->working_model->constraints_size(); ++c) { - ConstraintProto* const ct = context_->working_model->mutable_constraints(c); + for (int c = 0; c < working_model_->constraints_size(); ++c) { + ConstraintProto* const ct = working_model_->mutable_constraints(c); switch (ct->constraint_case()) { case ConstraintProto::kIntDiv: MaybeExpandNonAffineExpressions(ct->mutable_int_div()); @@ -1569,6 +1771,16 @@ void ModelCopy::ExpandNonAffineExpressions() { MaybeExpandNonAffineExpression(&expr); } break; + case ConstraintProto::kInverse: + for (LinearExpressionProto& expr : + *ct->mutable_inverse()->mutable_f_expr_direct()) { + MaybeExpandNonAffineExpression(&expr); + } + for (LinearExpressionProto& expr : + *ct->mutable_inverse()->mutable_f_expr_inverse()) { + MaybeExpandNonAffineExpression(&expr); + } + break; case ConstraintProto::kAutomaton: for (LinearExpressionProto& expr : *ct->mutable_automaton()->mutable_exprs()) { @@ -1606,12 +1818,21 @@ void ModelCopy::MaybeExpandNonAffineExpression(LinearExpressionProto* expr) { gcd = std::gcd(gcd, std::abs(expr->coeffs(i))); } Domain domain(0); + int64_t domain_min = 0; + int64_t domain_max = 0; std::vector> definition; + definition.reserve(expr->vars().size()); for (int i = 0; i < expr->vars().size(); ++i) { const int var = expr->vars(i); const int64_t coeff = expr->coeffs(i) / gcd; - domain = - domain.AdditionWith(context_->DomainOf(var).MultiplicationBy(coeff)); + const Domain& domain = helper_.MappedDomain(var); // no mapping. + if (coeff > 0) { + domain_min += coeff * domain.Min(); + domain_max += coeff * domain.Max(); + } else { + domain_min += coeff * domain.Max(); + domain_max += coeff * domain.Min(); + } definition.push_back({var, coeff}); } std::sort(definition.begin(), definition.end()); @@ -1631,10 +1852,9 @@ void ModelCopy::MaybeExpandNonAffineExpression(LinearExpressionProto* expr) { new_var = it->second; gcd = -gcd; } else { - new_var = context_->NewIntVar(domain); + new_var = helper_.NewIntVar(Domain(domain_min, domain_max)); non_affine_expression_to_new_var_[definition] = new_var; - auto* new_linear = - context_->working_model->add_constraints()->mutable_linear(); + auto* new_linear = working_model_->add_constraints()->mutable_linear(); new_linear->add_vars(new_var); new_linear->add_coeffs(-1); for (const auto [var, coeff] : definition) { @@ -1643,7 +1863,7 @@ void ModelCopy::MaybeExpandNonAffineExpression(LinearExpressionProto* expr) { } new_linear->add_domain(0); new_linear->add_domain(0); - context_->solution_crush().SetVarToLinearExpression(new_var, definition); + helper_.solution_crush()->SetVarToLinearExpression(new_var, definition); } } expr->clear_vars(); @@ -1661,59 +1881,44 @@ void ModelCopy::MaybeExpandNonAffineExpressions( } template -void ModelCopy::MapTerm(int& ref, T& coeff, T& offset) const { - CHECK(RefIsPositive(ref)); - if (context_->IsFixed(ref)) { - offset += coeff * context_->MinOf(ref); +void ModelCopy::MapTerm(int& var, T& coeff, T& offset) const { + DCHECK(RefIsPositive(var)); + if (helper_.InputIsFixed(var)) { + offset += coeff * helper_.InputFixedValue(var); coeff = 0; return; } if (variable_mapping_.empty()) return; - const int mapped_ref = variable_mapping_[ref]; + const int mapped_ref = variable_mapping_[var]; if (RefIsPositive(mapped_ref)) { - ref = mapped_ref; + var = mapped_ref; } else { - // Only Boolean variables can be mapped to a negated ref. If x is mapped to + // Only Boolean variables can be mapped to a negated var. If x is mapped to // NegatedRef(y), then coeff * x = coeff * (1 - y). - DCHECK(context_->CanBeUsedAsLiteral(ref)); offset += coeff; coeff = -coeff; - ref = NegatedRef(mapped_ref); + var = NegatedRef(mapped_ref); } -} - -const Domain& ModelCopy::MappedVarDomain(int mapped_var) const { - if (variable_mapping_.empty()) return context_->DomainOf(mapped_var); - const int ref = reverse_mapping_[mapped_var]; - const Domain& domain = context_->DomainOf(PositiveRef(ref)); - // A variable can only be reverse mapped to a negative variable reference - // if it is a Boolean variable, in which case the remapped domain is - // unchanged (unless it is fixed). - if (RefIsPositive(ref) || !domain.IsFixed()) { - return domain; + const Domain mapped_domain = helper_.MappedDomain(var); + if (mapped_domain.IsFixed()) { + offset += coeff * mapped_domain.Min(); + coeff = 0; } - DCHECK(domain.IsIncludedIn(Domain(0, 1))); - return domain.Min() == 0 ? domain1_ : domain0_; } int ModelCopy::GetTrueMappedLiteral() { if (!true_mapped_literal_.has_value()) { true_mapped_literal_ = kNoVariableMapping; - if (variable_mapping_.empty()) { - for (int i = 0; i < context_->working_model->variables_size(); ++i) { - if (context_->CanBeUsedAsLiteral(i) && context_->IsFixed(i)) { - true_mapped_literal_ = context_->LiteralIsTrue(i) ? i : NegatedRef(i); - break; - } - } - } else { - for (int i = 0; i < reverse_mapping_.size(); ++i) { - const int ref = reverse_mapping_[i]; - if (context_->CanBeUsedAsLiteral(ref) && context_->IsFixed(ref)) { - true_mapped_literal_ = - context_->LiteralIsTrue(ref) ? i : NegatedRef(i); - break; - } + const int num_vars = helper_.MappedDomains().size(); + for (int i = 0; i < num_vars; ++i) { + const Domain& domain = helper_.MappedDomain(i); + if (domain.IsEmpty() || !domain.IsFixed()) continue; + if (domain.Min() == 1) { + true_mapped_literal_ = i; + break; + } else if (domain.Min() == 0) { + true_mapped_literal_ = NegatedRef(i); + break; } } DCHECK_NE(*true_mapped_literal_, kNoVariableMapping); @@ -1721,35 +1926,50 @@ int ModelCopy::GetTrueMappedLiteral() { return *true_mapped_literal_; } -bool ImportModelWithBasicPresolveIntoContext(const CpModelProto& in_model, - PresolveContext* context) { - ModelCopy copier(context); - copier.ImportVariablesAndMaybeIgnoreNames(in_model); - if (copier.ImportAndSimplifyConstraints(in_model, /*first_copy=*/true)) { - copier.ImportSolutionHint(in_model); - return copier.ImportEverythingExceptVariablesConstraintsAndHint(in_model); - } - return !context->ModelIsUnsat(); -} - -bool ImportModelAndDomainsWithBasicPresolveIntoContext( - const CpModelProto& in_model, absl::Span domains, - std::function active_constraints, PresolveContext* context, - std::vector* interval_mapping) { - CHECK_EQ(domains.size(), in_model.variables_size()); - ModelCopy copier(context); - copier.CreateVariablesFromDomains(domains); - if (copier.ImportAndSimplifyConstraints(in_model, /*first_copy=*/false, - active_constraints)) { - copier.ImportSolutionHint(in_model); - if (!copier.ImportEverythingExceptVariablesConstraintsAndHint(in_model)) { - return false; - } - interval_mapping->assign(copier.InternalIntervalMapping().begin(), - copier.InternalIntervalMapping().end()); - return true; +bool CopyModel(const CpModelProto& in_proto, CpModelProto* out_proto, + Model* model) { + out_proto->Clear(); + ModelCopy copier(out_proto, model); + if (!copier.ImportVariables(in_proto)) return false; + copier.ImportSolutionHint(in_proto); + if (!copier.ImportAndSimplifyConstraints(in_proto, /*first_copy=*/true)) { + return false; + } + if (!copier.ImportEverythingExceptVariablesConstraintsAndHint(in_proto)) { + return false; } - return !context->ModelIsUnsat(); + if (!copier.FinishCopy(in_proto)) { + return false; + } + + return true; +} + +bool CopyModelAdvanced(const CpModelProto& in_proto, + absl::Span domains, + std::function active_constraints, + std::vector* interval_mapping, + CpModelProto* out_proto, Model* model) { + out_proto->Clear(); + CHECK_EQ(domains.size(), in_proto.variables_size()); + + ModelCopy copier(out_proto, model); + if (!copier.CreateVariablesFromDomains(domains)) return false; + copier.ImportSolutionHint(in_proto); + if (!copier.ImportAndSimplifyConstraints(in_proto, /*first_copy=*/false, + active_constraints)) { + return false; + } + if (!copier.ImportEverythingExceptVariablesConstraintsAndHint(in_proto)) { + return false; + } + if (!copier.FinishCopy(in_proto)) { + return false; + } + + interval_mapping->assign(copier.InternalIntervalMapping().begin(), + copier.InternalIntervalMapping().end()); + return true; } void VariableDomains::Reset(int num_vars) { @@ -1915,8 +2135,7 @@ bool DenseModelCopy::ComputeVariableMapping( } new_input_var_mapping.assign(num_input_vars, kNoVariableMapping); - fixed_input_var_values_.resize(num_input_vars, - std::numeric_limits::min()); + fixed_input_var_values_.resize(num_input_vars, kint64min); reverse_mapping_.clear(); int first_fixed_literal = -1; for (int input_var = 0; input_var < num_input_vars; ++input_var) { @@ -1979,8 +2198,7 @@ bool DenseModelCopy::ApplyVariableMapping( for (int input_var = 0; input_var < num_input_vars; ++input_var) { const int current_ref = input_var_mapping_[input_var]; if (input_var_mapping[input_var] == kNoVariableMapping) { - DCHECK_NE(fixed_input_var_values_[input_var], - std::numeric_limits::min()); + DCHECK_NE(fixed_input_var_values_[input_var], kint64min); input_var_domains.push_back(Domain(fixed_input_var_values_[input_var])); continue; } @@ -1995,17 +2213,18 @@ bool DenseModelCopy::ApplyVariableMapping( } Model local_model; - CpModelProto mapping_proto; - auto context = std::make_unique(&local_model, &model_proto_, - &mapping_proto); - ModelCopy copier(context.get(), input_var_mapping, reverse_mapping_); - copier.CreateVariablesFromDomains(input_var_domains); + ModelCopy copier(&model_proto_, &local_model, input_var_mapping); + if (!copier.CreateVariablesFromDomains(input_var_domains)) { + return false; + } if (!copier.ImportAndSimplifyConstraints(input_model_proto_)) { return false; } - copier.ImportEverythingExceptVariablesConstraintsAndHint( - input_model_proto_, /*copy_symmetry=*/false); - if (!copier.RemapVariablesInProtoAndContext()) { + if (!copier.ImportEverythingExceptVariablesConstraintsAndHint( + input_model_proto_, /*copy_symmetry=*/false)) { + return false; + } + if (!copier.FinishCopy(input_model_proto_)) { return false; } ResetVarDomains(); @@ -2021,7 +2240,7 @@ void DenseModelCopy::ResetVarDomains() { } std::vector DenseModelCopy::MapSolution( - absl::Span input_solution) { + absl::Span input_solution) const { std::vector solution; solution.reserve(reverse_mapping_.size()); for (int var = 0; var < reverse_mapping_.size(); ++var) { @@ -2037,7 +2256,7 @@ std::vector DenseModelCopy::MapSolution( } std::vector DenseModelCopy::ReverseMapSolution( - absl::Span solution) { + absl::Span solution) const { const int num_input_vars = input_var_mapping_.size(); std::vector input_solution; input_solution.reserve(num_input_vars); @@ -2045,7 +2264,7 @@ std::vector DenseModelCopy::ReverseMapSolution( const int ref = input_var_mapping_[input_var]; if (ref == kNoVariableMapping) { input_solution.push_back(fixed_input_var_values_[input_var]); - DCHECK_NE(input_solution.back(), std::numeric_limits::min()); + DCHECK_NE(input_solution.back(), kint64min); } else { const int64_t value = solution[PositiveRef(ref)]; input_solution.push_back(RefIsPositive(ref) ? value : 1 - value); diff --git a/ortools/sat/cp_model_copy.h b/ortools/sat/cp_model_copy.h index be36123a8fc..771eee21d03 100644 --- a/ortools/sat/cp_model_copy.h +++ b/ortools/sat/cp_model_copy.h @@ -14,6 +14,7 @@ #ifndef ORTOOLS_SAT_CP_MODEL_COPY_H_ #define ORTOOLS_SAT_CP_MODEL_COPY_H_ +#include #include #include #include @@ -27,6 +28,7 @@ #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/lrat_proof_handler.h" @@ -41,50 +43,167 @@ namespace sat { // A variable which is removed from the model during copy. // This can only be done if the variable is fixed. -constexpr int kNoVariableMapping = std::numeric_limits::min(); +constexpr int kNoVariableMapping = kint32min; -// This helper class perform copy with simplification from a model and a -// partial assignment to another model. The purpose is to minimize the size of -// the copied model, as well as to reduce the pressure on the memory sub-system. +// A simpler version of PresolveContext with just a few helper to query and +// manipulate variable domains during the copy and handle remapping. // -// It is currently used by the LNS part, but could be used with any other scheme -// that generates partial assignments. +// Only visible for testing. +class ModelCopyHelper { + public: + ModelCopyHelper() = default; + + // IMPORTANT: Must be called before any other functions in this class. + // + // Takes list of num_input_vars domains of the "input" cp_model proto before + // the mapping. The mapping can be empty or must be of the same same and will + // remap [0, num_input_vars) into a potential smaller [0, num_output_vars) + // dense space. + // + // More than one variable can be mapped to the same index, in which case these + // are assumed to be equivalent. And input variables with kNoVariableMapping + // must be fixed, that is InputIsFixed() should return true. + bool InitializeDomains(std::vector domains, + absl::Span mapping); + + // All Input*() functions accept variable indices in the input space. Note + // however that the "fixed" status might reflect changes we did in the mapped + // space as we perform the copy. + bool InputIsFixed(int ref) const; + bool InputFixedLiteralIsTrue(int ref) const; + int64_t InputFixedValue(int var) const; + std::optional InputFixedValueOrNullopt( + const LinearExpressionProto& expr) const; + + struct FixedLinearArgument { + int64_t target; + std::vector exprs; + }; + std::optional InputFixedLinearArgumentOrNullopt( + const LinearArgumentProto& linear_argument) const; + + // All the *Mapped*() functions work in the mapped space, after the mapping + // has been applied. All the mutable function are in this category. + ABSL_MUST_USE_RESULT bool IntersectMappedDomainWith(int var, + const Domain& domain); + + ABSL_MUST_USE_RESULT bool SetMappedLiteralToTrue(int ref) { + const int value = RefIsPositive(ref) ? 1 : 0; + return IntersectMappedDomainWith(PositiveRef(ref), Domain(value)); + } + + ABSL_MUST_USE_RESULT bool SetMappedLiteralToFalse(int ref) { + const int value = RefIsPositive(ref) ? 0 : 1; + return IntersectMappedDomainWith(PositiveRef(ref), Domain(value)); + } + + // The current variables domains in the mapped space. + absl::Span MappedDomains() const { return mapped_domains_; } + const Domain& MappedDomain(int var) const { return mapped_domains_[var]; } + int64_t MappedMinOf(const LinearExpressionProto& expr) const; + + // Create a new mapped variable. This is only allowed if there is no mapping. + int NewIntVar(const Domain& domain) { + UpdateRuleStats("new variable during copy"); + CHECK(mapping_.empty()); + const int index = mapped_domains_.size(); + mapped_domains_.push_back(domain); + solution_crush_.Resize(mapped_domains_.size()); + return index; + } + + // Used to track what happened during the copy. + // DisplaySummary() will list these statistics. + void UpdateRuleStats(std::string_view name, int num_times = 1) { + stats_by_rule_name_[name] += num_times; + } + void DisplaySummary(SolverLogger* logger) { + absl::btree_map sorted_rules(stats_by_rule_name_.begin(), + stats_by_rule_name_.end()); + for (const auto& entry : sorted_rules) { + if (entry.second == 1) { + SOLVER_LOG(logger, " - rule '", entry.first, "' was applied 1 time."); + } else { + SOLVER_LOG(logger, " - rule '", entry.first, "' was applied ", + FormatCounter(entry.second), " times."); + } + } + } + + // This is used to udate the solution hint as we create new variables. + SolutionCrush* solution_crush() { return &solution_crush_; } + + private: + // We track which of the variable in the input proto was or is now fixed. + // And for such fixed variable, its fixed value. + // + // This is mutable for speed as we can update that when a variable becomes + // fixed lazily. + mutable std::vector input_variable_is_fixed_; + mutable std::vector input_variable_fixed_values_; + + // Mapping from the input proto indexing to the new one. Fixed variable do not + // need to be mapped and can have a kNoVariableMapping entry. Note that + // Boolean might be mapped to negative reference. + std::vector mapping_; + + // The domain in the output variable index space. + std::vector mapped_domains_; + + // This is used temporarily to transfrom the hint during copy. + SolutionCrush solution_crush_; + + // Summary of the performed operations. + absl::flat_hash_map stats_by_rule_name_; +}; + +// This helper class performs copy with simplification from a CpModelProto and a +// partial assignment to another CpModelProto. The purpose is to minimize the +// size of the copied model, as well as to reduce the pressure on the memory +// sub-system. +// +// When first_copy is true, this is also responsible for canonicalizing the +// user-given model so that we don't have to handle all corner cases after this. class ModelCopy { public: // If `variable_mapping` is not empty, it is applied to all variable - // references in all the copied constraints. In this case, `context` must - // describe the variables before mapping. A fixed variable can be removed by - // setting its mapped value to `kNoVariableMapping`. If some Boolean variables - // are fixed, at least one of them must not be removed. A variable appearing - // in an InverseConstraintProto must not be removed. Several variables can be - // mapped to the same variable. As of 2025-03-25, non Boolean variables - // remapped to a negative variable reference are not supported. - // `variable_mapping` and `reverse_mapping` must remain valid and unchanged - // during the lifetime of the constructed instance. - explicit ModelCopy(PresolveContext* context, - absl::Span variable_mapping = {}, - absl::Span reverse_mapping = {}); - - // Copy variables from the in_model to the working model. It reads the - // 'ignore_names' parameters from the context, and keeps or deletes names - // accordingly. This must be done before importing constraints. The imported - // variables must be the ones before variable mapping, if any. They are not - // remapped in the context's working model (this must be done at the end, with - // RemapVariables()). - void ImportVariablesAndMaybeIgnoreNames(const CpModelProto& in_model); - - // Setup new variables from a vector of domains. This must be done before - // importing constraints. The imported variables must be the ones before - // variable mapping, if any. They are not remapped in the context's working - // model (this must be done at the end, with RemapVariables()). - void CreateVariablesFromDomains(absl::Span domains); + // references in all the copied constraints. + // + // Fixed variables can be removed by setting their mapped value to + // `kNoVariableMapping`. If some Boolean variables are fixed, at least one of + // them must not be removed. + // + // Several variables can be mapped to the same variable. Moreover, for Boolean + // only, we can map a Boolean to the negation of another by using a negated + // reference. + explicit ModelCopy(CpModelProto* out_proto, Model* model, + absl::Span variable_mapping = {}); + + // Setup new variables from the one in the given model. + // The imported variables must be the ones before variable mapping, if any. + // This must be done first before any other call here. Note that the actual + // IntegerVariableProto will only be written to the out_model in FinishCopy(). + // + // Returns false iff the model was infeasible. This can happen if we map + // two incompatible variables to the same one, or if some domain were empty. + ABSL_MUST_USE_RESULT bool ImportVariables(const CpModelProto& in_model); + + // Same as ImportVariables() but from a vector of Domain instead. + ABSL_MUST_USE_RESULT bool CreateVariablesFromDomains( + absl::Span domains); + + // This must be done before we import the constraints. The hint will be + // adapted if we ever create new variable as we canonicalize constraints. + // Note that the hint will only be written in FinishCopy(). + void ImportSolutionHint(const CpModelProto& in_model); // Copies all constraints from in_model to working model of the context. // // During the process, it will read variable domains from the context, and // simplify constraints to minimize the size of the copied model. // Thus it is important that the context->working_model already have the - // variables part copied. + // variables part copied. The hint must also be imported first so that it can + // be updated during the simplification. // // It returns false iff the model is proven infeasible. // @@ -93,23 +212,23 @@ class ModelCopy { // Note(user): If first_copy is true, we will reorder the scheduling // constraint so that they only use reference to previously defined intervals. // This allow to be more efficient later in a few preprocessing steps. - bool ImportAndSimplifyConstraints( + ABSL_MUST_USE_RESULT bool ImportAndSimplifyConstraints( const CpModelProto& in_model, bool first_copy = false, std::function active_constraints = nullptr); - bool ImportObjective(const CpModelProto& in_model); - void ImportSolutionHint(const CpModelProto& in_model); + // Imports and write the objective. + ABSL_MUST_USE_RESULT bool ImportObjective(const CpModelProto& in_model); // Copies the non constraint, non variables part of the model. `copy_symmetry` // is only supported if there is no variable mapping. - bool ImportEverythingExceptVariablesConstraintsAndHint( + ABSL_MUST_USE_RESULT bool ImportEverythingExceptVariablesConstraintsAndHint( const CpModelProto& in_model, bool copy_symmetry = true); - // Remaps all variables in the context's working model using the variable - // mapping passed at construction time. This must be done after all - // constraints have been imported. - // Returns false iff the model is proven infeasible. - bool RemapVariablesInProtoAndContext(); + // This must be called to finish the copy. Note that variables proto will only + // be filled at this stage in the out_model. + // + // Returns false iff the model was infeasible. + ABSL_MUST_USE_RESULT bool FinishCopy(const CpModelProto& in_model); // Advanced usage. When a model was copied, interval_mapping[i] will // contain for a copied interval with original index i, its new index. @@ -118,6 +237,9 @@ class ModelCopy { } private: + // Write a summary of what happen to the logger_. + void DisplaySummary(); + // Overwrites the out_model to be unsat. Returns false. // The arguments are used to log which constraint caused unsat. bool CreateUnsatModel(int c, const ConstraintProto& ct); @@ -128,6 +250,18 @@ class ModelCopy { void FinishEnforcementCopy(ConstraintProto* ct); // All these functions return false if the constraint is found infeasible. + + // Copy a constraint that is always false, returning false if the enforcement + // literals are empty or creating a constraint forcing at least one of them to + // be false. + bool CopyFalseConstraint(); + + // Copy a constraint that is either always or never satisfied. + bool CopyTrivialConstraint(bool is_always_satisfied) { + if (is_always_satisfied) return true; + return CopyFalseConstraint(); + } + bool CopyBoolOr(const ConstraintProto& ct); bool CopyBoolOrWithDupSupport(const ConstraintProto& ct, int one_based_cnf_index); @@ -151,7 +285,13 @@ class ModelCopy { template void CanonicalizeLinearExpression( const absl::flat_hash_set* enforcement_literals, - std::vector>& terms, T& offset) const; + std::vector>& terms, T& offset); + + // This fills the non_fixed_terms_ field. + template + int64_t FillNonFixedTermsAndReturnOffset(const T& proto_with_vars_and_coeffs, + int64_t offset = 0); + bool CopyAutomaton(const ConstraintProto& ct); bool CopyTable(const ConstraintProto& ct); bool CopyAllDiff(const ConstraintProto& ct); @@ -179,7 +319,6 @@ class ModelCopy { bool CopyObjective(const CpObjectiveProto& objective); void CopyFloatingPointObjective(const FloatObjectiveProto& objective); - void CopySolutionHint(const PartialVariableAssignment& hint); // Expands linear expressions with more than one variable in constraints which // internally only support affine expressions (such as all_diff, element, @@ -192,28 +331,24 @@ class ModelCopy { void MaybeExpandNonAffineExpression(LinearExpressionProto* expr); void MaybeExpandNonAffineExpressions(LinearArgumentProto* linear_argument); - int MapLiteral(int lit) { + int MapLiteralEvenIfFixed(int lit) { if (variable_mapping_.empty()) return lit; - if (context_->IsFixed(lit)) { + if (helper_.InputIsFixed(lit)) { const int true_mapped_lit = GetTrueMappedLiteral(); - return context_->LiteralIsTrue(lit) ? true_mapped_lit - : NegatedRef(true_mapped_lit); + return helper_.InputFixedLiteralIsTrue(lit) ? true_mapped_lit + : NegatedRef(true_mapped_lit); } - return MapRef(lit); + return MapLiteral(lit); } - // `ref` must not have a fixed value, otherwise it might have no mapping. - int MapRef(int ref) const { - if (variable_mapping_.empty()) return ref; - const int mapped_ref = variable_mapping_[PositiveRef(ref)]; + // `lit` must not have a fixed value, otherwise it might have no mapping. + // Use MapLiteralEvenIfFixed() if unsure. + int MapLiteral(int lit) const { + DCHECK(!helper_.InputIsFixed(lit)); + if (variable_mapping_.empty()) return lit; + const int mapped_ref = variable_mapping_[PositiveRef(lit)]; DCHECK_NE(mapped_ref, kNoVariableMapping); - return RefIsPositive(ref) ? mapped_ref : NegatedRef(mapped_ref); - } - - int ReverseMapRef(int mapped_ref) const { - if (variable_mapping_.empty()) return mapped_ref; - const int ref = reverse_mapping_[PositiveRef(mapped_ref)]; - return RefIsPositive(mapped_ref) ? ref : NegatedRef(ref); + return RefIsPositive(lit) ? mapped_ref : NegatedRef(mapped_ref); } // Normalizes `ref` to a positive reference, replaces fixed terms with an @@ -222,15 +357,19 @@ class ModelCopy { template void MapTerm(int& ref, T& coeff, T& offset) const; - // Returns the domain of `mapped_var`, computed from the domain of the - // original variable mapped to `mapped_var` (as stored in the context). - const Domain& MappedVarDomain(int mapped_var) const; + // Helper to convert old proto format (single var) int new + // LinearExpressionProto format. + void ConvertSingleVarFormatToExpr(int var, LinearExpressionProto* expr); int GetTrueMappedLiteral(); - PresolveContext* context_; + ModelCopyHelper helper_; + const SatParameters& params_; + SolverLogger* logger_; + CpModelProto* working_model_; + absl::Span variable_mapping_; - absl::Span reverse_mapping_; + // If some original Boolean variables are fixed at least one of them must not // be removed by the variable mapping, from which we compute this always true // mapped literal. @@ -262,12 +401,9 @@ class ModelCopy { // variables for the identical non affine expressions. absl::flat_hash_map>, int> non_affine_expression_to_new_var_; - - const Domain domain0_ = Domain(0); - const Domain domain1_ = Domain(1); }; -// Copy in_model to the model in the presolve context. +// Copy in_proto to out_proto. // It performs on the fly simplification, and returns false if the // model is proved infeasible. If reads the parameters 'ignore_names' and keeps // or deletes variables and constraints names accordingly. @@ -275,15 +411,16 @@ class ModelCopy { // This should only be called on the first copy of the user given model. // Note that this reorder all constraints that use intervals last. We loose the // user-defined order, but hopefully that should not matter too much. -bool ImportModelWithBasicPresolveIntoContext(const CpModelProto& in_model, - PresolveContext* context); +bool CopyModel(const CpModelProto& in_proto, CpModelProto* out_proto, + Model* model); -// Same as ImportModelWithBasicPresolveIntoContext() except that variable -// domains are read from domains and constraint might be filtered. -bool ImportModelAndDomainsWithBasicPresolveIntoContext( - const CpModelProto& in_model, absl::Span domains, - std::function active_constraints, PresolveContext* context, - std::vector* interval_mapping); +// Same as CopyModel() except that variable domains are read from domains and +// constraint might be filtered. +bool CopyModelAdvanced(const CpModelProto& in_proto, + absl::Span domains, + std::function active_constraints, + std::vector* interval_mapping, + CpModelProto* out_proto, Model* model); // Accessing Domain can be expensive, so we maintain vector of bool for the // hot spots. @@ -351,8 +488,10 @@ class DenseModelCopy { bool MaybeUpdate(bool& updated); // Maps a solution from the input model to the dense model, and vice versa. - std::vector MapSolution(absl::Span input_solution); - std::vector ReverseMapSolution(absl::Span solution); + std::vector MapSolution( + absl::Span input_solution) const; + std::vector ReverseMapSolution( + absl::Span solution) const; private: // Computes and applies a dense mapping of the variables which removes fixed diff --git a/ortools/sat/cp_model_copy_test.cc b/ortools/sat/cp_model_copy_test.cc index 074be06ce0d..76af0146e77 100644 --- a/ortools/sat/cp_model_copy_test.cc +++ b/ortools/sat/cp_model_copy_test.cc @@ -52,9 +52,7 @@ TEST(ModelCopyTest, IntervalsAddLinearConstraints) { Model model; CpModelProto new_cp_model; - PresolveContext context(&model, &new_cp_model, nullptr); - - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); const CpModelProto expected_model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ -10, 10 ] } @@ -105,9 +103,7 @@ TEST(ModelCopyTest, IntervalsWithFixedStartAndEnd) { Model model; CpModelProto new_cp_model; - PresolveContext context(&model, &new_cp_model, nullptr); - - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); const CpModelProto expected_model = ParseTestProto(R"pb( variables { domain: [ 10, 10 ] } variables { domain: [ 10, 10 ] } @@ -155,8 +151,7 @@ TEST(ModelCopyTest, RemoveDuplicateFromClauses) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } @@ -207,8 +202,7 @@ TEST(ModelCopyTest, RemoveDuplicateFromEnforcementLiterals) { Model model; model.GetOrCreate() ->set_keep_all_feasible_solutions_in_presolve(true); - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } @@ -246,15 +240,14 @@ TEST(ModelCopyTest, RemapLiteralsInBoolOr) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = { kNoVariableMapping, 1, -3, -3, 1, 0, kNoVariableMapping}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); const CpModelProto expected_moded = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -284,15 +277,14 @@ TEST(ModelCopyTest, RemapLiteralsInBoolAnd) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = { kNoVariableMapping, 0, 1, -2, 0, -2}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); const CpModelProto expected_moded = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -319,15 +311,14 @@ TEST(ModelCopyTest, RemapLiteralsInBoolXor) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = { kNoVariableMapping, kNoVariableMapping, 0, 1, 2, 3}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); const CpModelProto expected_moded = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -358,19 +349,18 @@ TEST(ModelCopyTest, RemapVariablesInLinear) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = {kNoVariableMapping, -1, 1, 2, 2}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); const CpModelProto expected_moded = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ 0, 10 ] } - variables { domain: [ 0, 11 ] } + variables { domain: [ 0, 10 ] } constraints { enforcement_literal: [ -1 ] linear { @@ -397,14 +387,13 @@ TEST(ModelCopyTest, RemapVariablesInLinear_CanonicalizeSingleBoolInDomain) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = {0, -1}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); // 5x + 6y \in [6, 7] is remapped to 5x + 6.not(x) = 6 - x \in [6, 7], whose // unique solution is x = 0. @@ -414,6 +403,151 @@ TEST(ModelCopyTest, RemapVariablesInLinear_CanonicalizeSingleBoolInDomain) { EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } +TEST(ModelCopyTest, RemapVariablesInLegacyInverse) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 2, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 2, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_direct: 0 + f_direct: 1 + f_direct: 2 + f_inverse: 3 + f_inverse: 4 + f_inverse: 5 + } + } + )pb"); + CpModelProto new_cp_model; + Model model; + const std::vector variable_mapping = {kNoVariableMapping, 0, 1, 1, + kNoVariableMapping, 0}; + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); + + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); + + const CpModelProto expected_moded = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_expr_direct: { offset: 2 } + f_expr_direct: { vars: 0 coeffs: 1 } + f_expr_direct: { vars: 1 coeffs: 1 } + f_expr_inverse: { vars: 1 coeffs: 1 } + f_expr_inverse: { offset: 2 } + f_expr_inverse: { vars: 0 coeffs: 1 } + } + } + )pb"); + EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); +} + +TEST(ModelCopyTest, RemapVariablesInInverse) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 1, 1 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_expr_direct: { vars: 0 coeffs: -1 offset: 3 } + f_expr_direct: { vars: 1 coeffs: 1 offset: -1 } + f_expr_direct: { vars: 2 coeffs: 3 offset: -2 } + f_expr_inverse: { vars: 0 coeffs: 1 offset: 0 } + f_expr_inverse: { vars: 1 coeffs: 1 offset: 1 } + f_expr_inverse: { vars: 2 coeffs: -2 offset: 2 } + } + } + )pb"); + CpModelProto new_cp_model; + Model model; + const std::vector variable_mapping = {kNoVariableMapping, 0, 1}; + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); + + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); + + const CpModelProto expected_moded = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_expr_direct: { offset: 2 } + f_expr_direct: { vars: 0 coeffs: 1 offset: -1 } + f_expr_direct: { vars: 1 coeffs: 3 offset: -2 } + f_expr_inverse: { offset: 1 } + f_expr_inverse: { vars: 0 coeffs: 1 offset: 1 } + f_expr_inverse: { vars: 1 coeffs: -2 offset: 2 } + } + } + )pb"); + EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); +} + +TEST(ModelCopyTest, ExpandNonAffineExpressionsInInverse) { + const CpModelProto initial_model = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + constraints { + inverse { + f_expr_direct: { vars: 0 coeffs: -1 offset: 3 } + f_expr_direct: { vars: 1 coeffs: 1 offset: -1 } + f_expr_direct: { vars: 2 coeffs: 3 offset: -2 } + f_expr_inverse: { vars: 0 coeffs: 1 offset: 0 } + f_expr_inverse: { vars: 1 coeffs: 1 offset: 1 } + f_expr_inverse: { + vars: [ 1, 2 ] + coeffs: [ -1, -1 ] + offset: 2 + } + } + } + )pb"); + CpModelProto new_cp_model; + Model model; + ModelCopy model_copy(&new_cp_model, &model); + + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); + + const CpModelProto expected_moded = ParseTestProto(R"pb( + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ 0, 2 ] } + variables { domain: [ -4, 0 ] } + constraints { + inverse { + f_expr_direct: { vars: 0 coeffs: -1 offset: 3 } + f_expr_direct: { vars: 1 coeffs: 1 offset: -1 } + f_expr_direct: { vars: 2 coeffs: 3 offset: -2 } + f_expr_inverse: { vars: 0 coeffs: 1 offset: 0 } + f_expr_inverse: { vars: 1 coeffs: 1 offset: 1 } + f_expr_inverse: { vars: 3 coeffs: 1 offset: 2 } + } + } + constraints { + linear { + vars: [ 3, 1, 2 ] + coeffs: [ -1, -1, -1 ] + domain: [ 0, 0 ] + } + } + )pb"); + EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); +} + TEST(ModelCopyTest, RemapVariablesInObjective) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -426,15 +560,15 @@ TEST(ModelCopyTest, RemapVariablesInObjective) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = {0, -1}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.ImportEverythingExceptVariablesConstraintsAndHint(initial_model); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.ImportEverythingExceptVariablesConstraintsAndHint( + initial_model)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); // 5x + 6y \in [0, 50] is remapped to 5x + 6.not(x) = 6 - x \in [0, 50]. const CpModelProto expected_moded = ParseTestProto(R"pb( @@ -443,7 +577,7 @@ TEST(ModelCopyTest, RemapVariablesInObjective) { vars: [ 0 ] coeffs: [ -1 ] offset: 6 - domain: [ -6, 5 ] + domain: [ -1, 0 ] integer_before_offset: 6 } )pb"); @@ -462,15 +596,15 @@ TEST(ModelCopyTest, RemapVariablesInFloatingPointObjective) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); const std::vector variable_mapping = {0, -1}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); - model_copy.ImportEverythingExceptVariablesConstraintsAndHint(initial_model); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); + EXPECT_TRUE(model_copy.ImportEverythingExceptVariablesConstraintsAndHint( + initial_model)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); // 5.5x + 6.25y + 2.75 is remapped to 5.5x + 6.25 * (1 - x) + 2.75. const CpModelProto expected_moded = ParseTestProto(R"pb( @@ -497,17 +631,17 @@ TEST(ModelCopyTest, RemapVariablesInSearchStrategyAssumptionsAndHint) { )pb"); CpModelProto new_cp_model; Model model; - PresolveContext context(&model, &new_cp_model, nullptr); // x0, x1 mapped to x0, not(x0). const std::vector variable_mapping = {0, -1}; - const std::vector reverse_mapping = ReverseMapping(variable_mapping); - ModelCopy model_copy(&context, variable_mapping, reverse_mapping); + ModelCopy model_copy(&new_cp_model, &model, variable_mapping); - model_copy.ImportVariablesAndMaybeIgnoreNames(initial_model); - model_copy.ImportAndSimplifyConstraints(initial_model, /*first_copy=*/true); + EXPECT_TRUE(model_copy.ImportVariables(initial_model)); + EXPECT_TRUE(model_copy.ImportAndSimplifyConstraints(initial_model, + /*first_copy=*/true)); model_copy.ImportSolutionHint(initial_model); - model_copy.ImportEverythingExceptVariablesConstraintsAndHint(initial_model); - model_copy.RemapVariablesInProtoAndContext(); + EXPECT_TRUE(model_copy.ImportEverythingExceptVariablesConstraintsAndHint( + initial_model)); + EXPECT_TRUE(model_copy.FinishCopy(initial_model)); const CpModelProto expected_moded = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } @@ -565,8 +699,7 @@ TEST(ModelCopyTest, ChangeEnforcedAtMostOrExactlyOneToLinear) { Model model; model.GetOrCreate() ->set_keep_all_feasible_solutions_in_presolve(true); - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } @@ -604,8 +737,7 @@ TEST(ModelCopyTest, LegacyElementConstraint) { Model model; model.GetOrCreate() ->set_keep_all_feasible_solutions_in_presolve(true); - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } @@ -644,8 +776,7 @@ TEST(ModelCopyTest, ElementConstraint) { Model model; model.GetOrCreate() ->set_keep_all_feasible_solutions_in_presolve(true); - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } @@ -696,8 +827,7 @@ TEST(ModelCopyTest, ExpandedNonAffineExpressionsShareVariableWhenPossible) { Model model; model.GetOrCreate() ->set_keep_all_feasible_solutions_in_presolve(true); - PresolveContext context(&model, &new_cp_model, nullptr); - ImportModelWithBasicPresolveIntoContext(initial_model, &context); + CopyModel(initial_model, &new_cp_model, &model); EXPECT_THAT(new_cp_model, EqualsProto(expected_moded)); } diff --git a/ortools/sat/cp_model_expand.cc b/ortools/sat/cp_model_expand.cc index d415c738c49..ef8caf5e598 100644 --- a/ortools/sat/cp_model_expand.cc +++ b/ortools/sat/cp_model_expand.cc @@ -32,6 +32,7 @@ #include "absl/types/span.h" #include "google/protobuf/message.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" @@ -52,8 +53,7 @@ void ExpandAlwaysFalseConstraint(ConstraintProto* ct, PresolveContext* context, if (ct->enforcement_literal().empty()) { return (void)context->NotifyThatModelIsUnsat(message); } - BoolArgumentProto& bool_or = - *context->working_model->add_constraints()->mutable_bool_or(); + BoolArgumentProto& bool_or = *context->AddConstraint()->mutable_bool_or(); for (const int literal : ct->enforcement_literal()) { bool_or.add_literals(NegatedRef(literal)); } @@ -212,8 +212,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, // The encoding will create a circuit constraint, and one integer variable per // event (representing the level at that event time). - CircuitConstraintProto* circuit = - context->working_model->add_constraints()->mutable_circuit(); + CircuitConstraintProto* circuit = context->AddConstraint()->mutable_circuit(); const int64_t var_min = std::max(reservoir.min_level(), sum_of_negative_demand); @@ -293,7 +292,7 @@ void ExpandReservoirUsingCircuit(int64_t sum_of_positive_demand, ConstraintProto* new_ct = context->AddEnforcedConstraint(reservoir_ct); new_ct->add_enforcement_literal(arc_i_j); LinearConstraintProto* lin = new_ct->mutable_linear(); - FillDomainInProto(0, std::numeric_limits::max(), lin); + FillDomainInProto(0, kint64max, lin); AddLinearExpressionToLinearConstraint(reservoir.time_exprs(j), 1, lin); AddLinearExpressionToLinearConstraint(reservoir.time_exprs(i), -1, lin); context->CanonicalizeLinearConstraint(new_ct); @@ -384,17 +383,15 @@ void ExpandReservoirUsingPrecedences(bool max_level_is_constraining, // Note that according to the sign of demand_i, we only need one side. // We apply the offset here to make sure we use int64_t min and max. if (demand_i > 0) { - FillDomainInProto(std::numeric_limits::min(), - reservoir.max_level() - offset, new_linear); + FillDomainInProto(kint64min, reservoir.max_level() - offset, new_linear); } else { - FillDomainInProto(reservoir.min_level() - offset, - std::numeric_limits::max(), new_linear); + FillDomainInProto(reservoir.min_level() - offset, kint64max, new_linear); } // Canonicalize the newly created constraint. context->CanonicalizeLinearConstraint(new_cumul); - DCHECK(!PossibleIntegerOverflow(*context->working_model, new_linear->vars(), + DCHECK(!PossibleIntegerOverflow(context->WorkingModel(), new_linear->vars(), new_linear->coeffs())); } @@ -534,14 +531,14 @@ void EncodeCumulativeAsReservoir(ConstraintProto* ct, ConstraintProto reservoir_ct; *reservoir_ct.mutable_enforcement_literal() = ct->enforcement_literal(); auto* reservoir = reservoir_ct.mutable_reservoir(); - reservoir->set_min_level(std::numeric_limits::min()); + reservoir->set_min_level(kint64min); reservoir->set_max_level(context->FixedValue(ct->cumulative().capacity())); const int true_literal = context->GetTrueLiteral(); const int num_intervals = ct->cumulative().intervals().size(); for (int i = 0; i < num_intervals; ++i) { const auto& interval_ct = - context->working_model->constraints(ct->cumulative().intervals(i)); + context->Constraint(ct->cumulative().intervals(i)); const auto& interval = interval_ct.interval(); *reservoir->add_time_exprs() = interval.start(); *reservoir->add_time_exprs() = interval.end(); @@ -676,52 +673,35 @@ void ExpandIntProd(ConstraintProto* ct, PresolveContext* context) { } void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { - const auto& f_direct = ct->inverse().f_direct(); - const auto& f_inverse = ct->inverse().f_inverse(); - const int n = f_direct.size(); - CHECK_EQ(n, f_inverse.size()); - - // Make sure the domains are included in [0, n - 1). - // Note that if a variable and its negation appear, the domains will be set to - // zero here. + DCHECK(ct->inverse().f_direct().empty()); + const auto& f_expr_direct = ct->inverse().f_expr_direct(); + const auto& f_expr_inverse = ct->inverse().f_expr_inverse(); + const int n = f_expr_direct.size(); + CHECK_EQ(n, f_expr_inverse.size()); + + // Make sure the variable domains give expression domains in [0, n - 1]. + // Note that if a variable appears in several expressions, its domain will be + // the intersection of the domains implied by each expression. In particular, + // if there are two expressions "v" and "-v", the domain of v will be {0}. // // TODO(user): Add support for UNSAT at expansion. This should create empty // domain if UNSAT, so it should still work correctly. EnforcedDomains enforced_domains(ct, context); - for (const int ref : f_direct) { + for (const LinearExpressionProto& expr : f_expr_direct) { if (!enforced_domains.IntersectDomainWith( - ref, Domain(0, n - 1), + expr, Domain(0, n - 1), "Empty domain for a variable in ExpandInverse()")) { return; } } - for (const int ref : f_inverse) { + for (const LinearExpressionProto& expr : f_expr_inverse) { if (!enforced_domains.IntersectDomainWith( - ref, Domain(0, n - 1), + expr, Domain(0, n - 1), "Empty domain for a variable in ExpandInverse()")) { return; } } - // If we have duplicate variables, we make sure the domain are reduced - // as the loop below might not detect incompatibilities. - if (enforced_domains.size() != 2 * n) { - for (int i = 0; i < n; ++i) { - for (int j = 0; j < n; ++j) { - // Note that if we don't have the same sign, both domain are at zero. - if (PositiveRef(f_direct[i]) != PositiveRef(f_inverse[j])) continue; - - // We can't have i or j as value if i != j. - if (i == j) continue; - if (!enforced_domains.IntersectDomainWith( - f_direct[i], Domain::FromValues({i, j}).Complement(), - "Empty domain for a variable in ExpandInverse()")) { - return; - } - } - } - } - // Reduce the domains of each variable by checking that the inverse value // exists. std::vector possible_values; @@ -729,65 +709,103 @@ void ExpandInverse(ConstraintProto* ct, PresolveContext* context) { // Propagate from one vector to its counterpart. const auto filter_inverse_domain = [&enforced_domains, n, &possible_values]( const auto& direct, - const auto& inverse) { + const auto& inverse, bool& changed) { // Propagate from the inverse vector to the direct vector. for (int i = 0; i < n; ++i) { - possible_values.clear(); - const Domain domain = enforced_domains.DomainOf(direct[i]); - bool removed_value = false; - for (const int64_t j : domain.Values()) { - if (enforced_domains.DomainOf(inverse[j]).Contains(i)) { - possible_values.push_back(j); - } else { - removed_value = true; - } - } - if (removed_value) { + if (direct[i].vars().empty()) { + const int64_t j = direct[i].offset(); if (!enforced_domains.IntersectDomainWith( - direct[i], Domain::FromValues(possible_values), - "Empty domain for a variable in ExpandInverse()")) { + inverse[j], Domain(i), + "Empty domain for a variable in ExpandInverse()", &changed)) { return false; } + } else { + possible_values.clear(); + bool removed_value = false; + for (const int64_t i_var_value : + enforced_domains.DomainOf(direct[i].vars(0)).Values()) { + const int64_t j = + direct[i].coeffs(0) * i_var_value + direct[i].offset(); + if (enforced_domains.DomainContains(inverse[j], i)) { + possible_values.push_back(j); + } else { + removed_value = true; + } + } + if (removed_value) { + changed = true; + if (!enforced_domains.IntersectDomainWith( + direct[i], Domain::FromValues(possible_values), + "Empty domain for a variable in ExpandInverse()")) { + return false; + } + } } } return true; }; - // Note that this should reach the fixed point in one pass. - // However, if we have duplicate variable, I am not sure. - if (!filter_inverse_domain(f_direct, f_inverse)) return; - if (!filter_inverse_domain(f_inverse, f_direct)) return; + // Loop until fix point (several iterations might be needed if some + // expressions share the same variable). + for (int i = 0; i < 100; ++i) { // Just to avoid potentially bad cases. + bool changed = false; + if (!filter_inverse_domain(f_expr_direct, f_expr_inverse, changed)) return; + if (!filter_inverse_domain(f_expr_inverse, f_expr_direct, changed)) return; + if (!changed) break; + } enforced_domains.MaybeAddEnforcedDomainConstraints(); // Expand the inverse constraint by associating literal to var == value // and sharing them between the direct and inverse variables. // - // Note that this is only correct because the domain are tight now. + // Note that this is only correct because the domains are tight now. for (int i = 0; i < n; ++i) { - const int f_i = f_direct[i]; - for (const int64_t j : enforced_domains.DomainOf(f_i).Values()) { - const int r_j = f_inverse[j]; + const LinearExpressionProto& f_i = f_expr_direct[i]; + // Fixed expressions are already handled by the propagation above. + if (f_i.vars().empty()) continue; + const int f_i_var = f_i.vars(0); + const int64_t f_i_coeff = f_i.coeffs(0); + const int64_t f_i_offset = f_i.offset(); + for (const int64_t f_i_var_value : + enforced_domains.DomainOf(f_i_var).Values()) { + const int64_t j = f_i_coeff * f_i_var_value + f_i_offset; + const LinearExpressionProto& r_j = f_expr_inverse[j]; + if (r_j.vars().empty()) continue; + const int r_j_var = r_j.vars(0); + const int64_t r_j_coeff = r_j.coeffs(0); + const int64_t r_j_offset = r_j.offset(); + // We have f_i_var * f_i_coeff + f_i_offset == j <=> + // r_j_var * r_j_coeff + r_j_offset == i + // which is equivalent to + // f_i_var == f_i_var_value <=> + // r_j_var == r_j_var_value = (i - r_j_offset) / r_j_coeff + DCHECK_EQ((i - r_j_offset) % r_j_coeff, 0); + const int64_t r_j_var_value = (i - r_j_offset) / r_j_coeff; if (ct->enforcement_literal().empty()) { - // We have f[i] == j <=> r[j] == i; int r_j_i; if (enforced_domains.DomainContains(r_j, i) && - context->VarCanTakeValue(r_j, i) && - context->HasVarValueEncoding(r_j, i, &r_j_i)) { - if (!context->InsertVarValueEncoding(r_j_i, f_i, j)) { + context->VarCanTakeValue(r_j_var, r_j_var_value) && + context->HasVarValueEncoding(r_j_var, r_j_var_value, &r_j_i)) { + if (!context->InsertVarValueEncoding(r_j_i, f_i_var, f_i_var_value)) { return; } } else { - const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); - if (!context->InsertVarValueEncoding(f_i_j, r_j, i)) { + const int f_i_j = + context->GetOrCreateVarValueEncoding(f_i_var, f_i_var_value); + if (!context->InsertVarValueEncoding(f_i_j, r_j_var, r_j_var_value)) { return; } } } else { - // We have enforcement_literal && f[i] == j => r[j] == i; - // We have enforcement_literal && r[j] == i => f[i] == j; - const int f_i_j = context->GetOrCreateVarValueEncoding(f_i, j); - const int r_j_i = context->GetOrCreateVarValueEncoding(r_j, i); + // enforcement_literal && f_i_var == f_i_var_value => + // r_j_var == r_j_var_value + // enforcement_literal && r_j_var == r_j_var_value => + // f_i_var == f_i_var_value + const int f_i_j = + context->GetOrCreateVarValueEncoding(f_i_var, f_i_var_value); + const int r_j_i = + context->GetOrCreateVarValueEncoding(r_j_var, r_j_var_value); if (f_i_j != r_j_i) { ConstraintProto* eq_direct = context->AddEnforcedConstraint(ct); eq_direct->add_enforcement_literal(f_i_j); @@ -820,7 +838,7 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { ConstraintProto* new_ct = context->AddEnforcedConstraint(ct); LinearConstraintProto* lin = new_ct->mutable_linear(); - FillDomainInProto(0, std::numeric_limits::max(), lin); + FillDomainInProto(0, kint64max, lin); AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); AddLinearExpressionToLinearConstraint(expr, -1, lin); context->CanonicalizeLinearConstraint(new_ct); @@ -848,7 +866,7 @@ void ExpandLinMax(ConstraintProto* ct, PresolveContext* context) { ConstraintProto* new_ct = context->AddEnforcedConstraint({enforcement_literals[i]}); LinearConstraintProto* lin = new_ct->mutable_linear(); - FillDomainInProto(std::numeric_limits::min(), 0, lin); + FillDomainInProto(kint64min, 0, lin); AddLinearExpressionToLinearConstraint(ct->lin_max().target(), 1, lin); AddLinearExpressionToLinearConstraint(ct->lin_max().exprs(i), -1, lin); context->CanonicalizeLinearConstraint(new_ct); @@ -1020,7 +1038,6 @@ void ExpandVariableElement(ConstraintProto* ct, PresolveContext* context, mapping_ct->mutable_linear()->add_domain(mapping_offset); mapping_ct->mutable_linear()->add_domain(mapping_offset); - context->UpdateNewConstraintsVariableUsage(); context->MarkVariableAsRemoved(index_var); context->UpdateRuleStats( "element: expanded variable element with unused index"); @@ -1048,7 +1065,8 @@ void ExpandVariableElement(ConstraintProto* ct, PresolveContext* context, } if (domain_is_exact && is_unique) { - for (const int64_t v : context->DomainOf(index_var).Values()) { + const Domain index_domain = context->DomainOf(index_var); + for (const int64_t v : index_domain.Values()) { const int64_t index_lit = context->GetOrCreateVarValueEncoding(index_var, v); const int64_t index_value = AffineExpressionValueAt(index, v); @@ -1062,7 +1080,6 @@ void ExpandVariableElement(ConstraintProto* ct, PresolveContext* context, AddLinearExpressionToLinearConstraint(expr, 1, imply->mutable_linear()); context->CanonicalizeLinearConstraint(imply); } - context->UpdateNewConstraintsVariableUsage(); context->UpdateRuleStats( "element: expanded variable element with unused target"); context->MarkVariableAsRemoved(target_var); @@ -1129,7 +1146,7 @@ void ExpandVariableElement(ConstraintProto* ct, PresolveContext* context, context->CanonicalizeLinearConstraint(imply); // Note that this should have been checked at model validation. - DCHECK(!PossibleIntegerOverflow(*context->working_model, + DCHECK(!PossibleIntegerOverflow(context->WorkingModel(), imply->mutable_linear()->vars(), imply->mutable_linear()->coeffs())) << google::protobuf::ShortFormat(*imply); @@ -1798,8 +1815,8 @@ void ProcessOneCompressedColumn( // bit, especially the linear relaxation. BoolArgumentProto* no_support = use_exo && !value_is_multiple.contains(value) - ? context->working_model->add_constraints()->mutable_exactly_one() - : context->working_model->add_constraints()->mutable_bool_or(); + ? context->AddConstraint()->mutable_exactly_one() + : context->AddConstraint()->mutable_bool_or(); for (; i < pairs.size() && pairs[i].first == value; ++i) { no_support->add_literals(pairs[i].second); @@ -1885,8 +1902,8 @@ void AddSizeTwoTable( } } if (exclusive) { - BoolArgumentProto* exo = context->working_model->add_constraints() - ->mutable_exactly_one(); + BoolArgumentProto* exo = + context->AddConstraint()->mutable_exactly_one(); for (const int support_literal : support_literals) { exo->add_literals(support_literal); } @@ -1894,7 +1911,7 @@ void AddSizeTwoTable( ++num_exo_added; } else { BoolArgumentProto* bool_or = - context->working_model->add_constraints()->mutable_bool_or(); + context->AddConstraint()->mutable_bool_or(); for (const int support_literal : support_literals) { bool_or->add_literals(support_literal); } @@ -1966,7 +1983,7 @@ bool ReduceTableInPresenceOfUniqueVariableWithCosts( // Rewrite the tuples. // put the cost last. - int64_t min_cost = std::numeric_limits::max(); + int64_t min_cost = kint64max; std::vector temp; for (int i = 0; i < tuples->size(); ++i) { int64_t cost = 0; @@ -2104,7 +2121,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, domain_sizes.push_back(context->DomainOf(var).Size()); } if (last_column_is_cost) { - domain_sizes.push_back(std::numeric_limits::max()); + domain_sizes.push_back(kint64max); } // We start by compressing the table with kTableAnyValue only. @@ -2194,7 +2211,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, // Create one Boolean variable per tuple to indicate if it can still be // selected or not. Enforce an exactly one between them. BoolArgumentProto* exactly_one = - context->working_model->add_constraints()->mutable_exactly_one(); + context->AddConstraint()->mutable_exactly_one(); std::optional table_is_active_literal = std::nullopt; // Process enforcement literals. @@ -2205,8 +2222,7 @@ void CompressAndExpandPositiveTable(ConstraintProto* ct, context->NewBoolVarWithConjunction(ct->enforcement_literal()); // Adds table_is_active <=> and(enforcement_literals). - BoolArgumentProto* bool_or = - context->working_model->add_constraints()->mutable_bool_or(); + BoolArgumentProto* bool_or = context->AddConstraint()->mutable_bool_or(); bool_or->add_literals(table_is_active_literal.value()); for (const int lit : ct->enforcement_literal()) { context->AddImplication(table_is_active_literal.value(), lit); @@ -2344,8 +2360,7 @@ void ExpandPositiveTable(ConstraintProto* ct, PresolveContext* context) { return (void)context->NotifyThatModelIsUnsat(); } else { context->UpdateRuleStats("table: enforced and empty"); - BoolArgumentProto* bool_or = - context->working_model->add_constraints()->mutable_bool_or(); + BoolArgumentProto* bool_or = context->AddConstraint()->mutable_bool_or(); for (const int lit : ct->enforcement_literal()) { bool_or->add_literals(NegatedRef(lit)); } @@ -2407,13 +2422,13 @@ void ExpandPositiveTable(ConstraintProto* ct, PresolveContext* context) { } bool AllDiffShouldBeExpanded(const Domain& union_of_domains, - const ConstraintProto* ct, + const ConstraintProto& ct, PresolveContext* context) { if (union_of_domains.Size() > context->params().max_alldiff_domain_size()) { return false; } - const AllDifferentConstraintProto& proto = ct->all_diff(); + const AllDifferentConstraintProto& proto = ct.all_diff(); const int num_exprs = proto.exprs_size(); int num_fully_encoded = 0; for (int i = 0; i < num_exprs; ++i) { @@ -2475,7 +2490,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, // a single Boolean. single_bool = context->NewBoolVar("complex linear expansion"); } else { - clause = context->working_model->add_constraints()->mutable_bool_or(); + clause = context->AddConstraint()->mutable_bool_or(); for (const int ref : ct->enforcement_literal()) { clause->add_literals(NegatedRef(ref)); } @@ -2502,7 +2517,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, // Create a new constraint which is a copy of the original, but with a // simple sub-domain and enforcement literal. - ConstraintProto* new_ct = context->working_model->add_constraints(); + ConstraintProto* new_ct = context->AddConstraint(); *new_ct = *ct; new_ct->add_enforcement_literal(subdomain_literal); FillDomainInProto(Domain(lb, ub), new_ct->mutable_linear()); @@ -2519,7 +2534,7 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, } else { linear_is_enforced = context->NewBoolVar("complex linear expansion"); BoolArgumentProto* maintain_linear_is_enforced = - context->working_model->add_constraints()->mutable_bool_or(); + context->AddConstraint()->mutable_bool_or(); for (const int e_lit : enforcement_literals) { context->AddImplication(NegatedRef(e_lit), NegatedRef(linear_is_enforced)); @@ -2540,7 +2555,6 @@ void ExpandComplexLinearConstraint(int c, ConstraintProto* ct, context->UpdateRuleStats("linear: expanded complex rhs"); context->InitializeNewDomains(); - context->UpdateNewConstraintsVariableUsage(); context->UpdateConstraintVariableUsage(c); } @@ -2575,17 +2589,17 @@ bool IsVarEqOrNeqValue(PresolveContext* context, // terms, and with a tight domain ( == cst). // TODO(user): The above rule is complex. Revisit. void ScanModelAndDecideAllDiffExpansion( - const ConstraintProto* all_diff_ct, PresolveContext* context, + const ConstraintProto& all_diff_ct, PresolveContext* context, absl::flat_hash_set& domain_of_var_is_used, absl::flat_hash_set& bounds_of_var_are_used, absl::flat_hash_set& processed_variables, bool& expand, bool& keep) { - CHECK_EQ(all_diff_ct->constraint_case(), ConstraintProto::kAllDiff); + CHECK_EQ(all_diff_ct.constraint_case(), ConstraintProto::kAllDiff); bool at_least_one_var_domain_is_used = false; bool at_least_one_var_bound_is_used = false; // Scan variables. - for (const LinearExpressionProto& expr : all_diff_ct->all_diff().exprs()) { + for (const LinearExpressionProto& expr : all_diff_ct.all_diff().exprs()) { // Skip constant expressions. if (expr.vars().empty()) continue; DCHECK_EQ(1, expr.vars_size()); @@ -2609,8 +2623,7 @@ void ScanModelAndDecideAllDiffExpansion( // Skip artificial constraints. if (ct_index < 0) continue; - const ConstraintProto& other_ct = - context->working_model->constraints(ct_index); + const ConstraintProto& other_ct = context->Constraint(ct_index); switch (other_ct.constraint_case()) { case ConstraintProto::ConstraintCase::kBoolOr: break; @@ -2710,13 +2723,14 @@ void ScanModelAndDecideAllDiffExpansion( keep = at_least_one_var_bound_is_used; } -void MaybeExpandAllDiff(ConstraintProto* ct, PresolveContext* context, +void MaybeExpandAllDiff(int c, PresolveContext* context, absl::flat_hash_set& domain_of_var_is_used, absl::flat_hash_set& bounds_of_var_are_used, absl::flat_hash_set& processed_variable) { const bool expand_all_diff_from_parameters = context->params().expand_alldiff_constraints(); - AllDifferentConstraintProto& proto = *ct->mutable_all_diff(); + const ConstraintProto& ct = context->Constraint(c); + const AllDifferentConstraintProto& proto = ct.all_diff(); if (proto.exprs_size() <= 1) return; if (context->ModelIsUnsat()) return; @@ -2766,10 +2780,10 @@ void MaybeExpandAllDiff(ConstraintProto* ct, PresolveContext* context, if (fixed_expression_count > 1) { // Violates the definition of AllDifferent. - ExpandAlwaysFalseConstraint(ct, context); + ExpandAlwaysFalseConstraint(context->MutableConstraint(c), context); return; } else if (fixed_expression_count == 1 && - ct->enforcement_literal().empty()) { + ct.enforcement_literal().empty()) { // Remove values from other domains. for (const LinearExpressionProto& expr : possible_exprs) { if (context->IsFixed(expr)) continue; @@ -2800,7 +2814,7 @@ void MaybeExpandAllDiff(ConstraintProto* ct, PresolveContext* context, context->UpdateRuleStats( absl::StrCat("all_diff:", is_a_permutation ? " permutation" : "", " expanded", keep_after_expansion ? " and kept" : "")); - if (!keep_after_expansion) ct->Clear(); + if (!keep_after_expansion) context->ClearConstraint(c); } } // namespace @@ -2823,58 +2837,59 @@ void ExpandCpModel(PresolveContext* context) { bool has_all_diffs = false; // First pass: we look at constraints that may fully encode variables. - for (int c = 0; c < context->working_model->constraints_size(); ++c) { - ConstraintProto* const ct = context->working_model->mutable_constraints(c); + for (int c = 0; c < context->NumConstraints(); ++c) { + const ConstraintProto& ct = context->Constraint(c); bool skip = false; - switch (ct->constraint_case()) { + switch (ct.constraint_case()) { case ConstraintProto::kLinear: // If we only do expansion, we do that as part of the main loop. // This way we don't need to call FinalExpansionForLinearConstraint(). - if (ct->linear().domain().size() > 2 && + if (ct.linear().domain().size() > 2 && !context->params().cp_model_presolve()) { - ExpandComplexLinearConstraint(c, ct, context); + ExpandComplexLinearConstraint(c, context->MutableConstraint(c), + context); } break; case ConstraintProto::kReservoir: if (context->params().expand_reservoir_constraints()) { - ExpandReservoir(ct, context); + ExpandReservoir(context->MutableConstraint(c), context); } break; case ConstraintProto::kCumulative: if (context->params().encode_cumulative_as_reservoir()) { - EncodeCumulativeAsReservoir(ct, context); + EncodeCumulativeAsReservoir(context->MutableConstraint(c), context); } break; case ConstraintProto::kIntMod: - ExpandIntMod(ct, context); + ExpandIntMod(context->MutableConstraint(c), context); break; case ConstraintProto::kIntProd: - ExpandIntProd(ct, context); + ExpandIntProd(context->MutableConstraint(c), context); break; case ConstraintProto::kElement: - ExpandElement(ct, context); + ExpandElement(context->MutableConstraint(c), context); break; case ConstraintProto::kInverse: - ExpandInverse(ct, context); + ExpandInverse(context->MutableConstraint(c), context); break; case ConstraintProto::kAutomaton: - ExpandAutomaton(ct, context); + ExpandAutomaton(context->MutableConstraint(c), context); break; case ConstraintProto::kTable: if (!context->params().cp_model_presolve() || context->time_limit()->LimitReached()) { - CanonicalizeTable(context, ct); + CanonicalizeTable(context, context->MutableConstraint(c)); } - if (ct->table().negated()) { - ExpandNegativeTable(ct, context); + if (ct.table().negated()) { + ExpandNegativeTable(context->MutableConstraint(c), context); } else { - ExpandPositiveTable(ct, context); + ExpandPositiveTable(context->MutableConstraint(c), context); } break; case ConstraintProto::kLinMax: - if (ct->lin_max().exprs().size() <= + if (ct.lin_max().exprs().size() <= context->params().max_lin_max_size_for_expansion()) { - ExpandLinMax(ct, context); + ExpandLinMax(context->MutableConstraint(c), context); } break; case ConstraintProto::kAllDiff: @@ -2888,15 +2903,14 @@ void ExpandCpModel(PresolveContext* context) { if (skip) continue; // Nothing was done for this constraint. // Update variable-constraint graph. - context->UpdateNewConstraintsVariableUsage(); - if (ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { + if (ct.constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { context->UpdateConstraintVariableUsage(c); } // Early exit if the model is unsat. if (context->ModelIsUnsat()) { SOLVER_LOG(context->logger(), "UNSAT after expansion of ", - ProtobufShortDebugString(*ct)); + ProtobufShortDebugString(ct)); return; } } @@ -2908,18 +2922,19 @@ void ExpandCpModel(PresolveContext* context) { absl::flat_hash_set domain_of_var_is_used; absl::flat_hash_set bounds_of_var_are_used; absl::flat_hash_set processed_variables; - for (int i = 0; i < context->working_model->constraints_size(); ++i) { - ConstraintProto* const ct = context->working_model->mutable_constraints(i); + for (int c = 0; c < context->NumConstraints(); ++c) { bool skip = false; - switch (ct->constraint_case()) { + const ConstraintProto& ct = context->Constraint(c); + switch (ct.constraint_case()) { case ConstraintProto::kAtMostOne: case ConstraintProto::kExactlyOne: // We do those in the second pass since MaybeExpandAllDiff() below may // create such constraints. - ExpandEnforcedAtMostOneOrExactlyOneConstraint(ct, i, context); + ExpandEnforcedAtMostOneOrExactlyOneConstraint( + context->MutableConstraint(c), c, context); break; case ConstraintProto::kAllDiff: - MaybeExpandAllDiff(ct, context, domain_of_var_is_used, + MaybeExpandAllDiff(c, context, domain_of_var_is_used, bounds_of_var_are_used, processed_variables); break; default: @@ -2930,15 +2945,14 @@ void ExpandCpModel(PresolveContext* context) { if (skip) continue; // Nothing was done for this constraint. // Update variable-constraint graph. - context->UpdateNewConstraintsVariableUsage(); - if (ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { - context->UpdateConstraintVariableUsage(i); + if (ct.constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { + context->UpdateConstraintVariableUsage(c); } // Early exit if the model is unsat. if (context->ModelIsUnsat()) { SOLVER_LOG(context->logger(), "UNSAT after expansion of ", - ProtobufShortDebugString(*ct)); + ProtobufShortDebugString(ct)); return; } } @@ -2948,27 +2962,22 @@ void ExpandCpModel(PresolveContext* context) { // of the expansion phase. context->ClearPrecedenceCache(); - // Make sure the context is consistent. + // Make sure the context is consistent and the proto domains are up to date. context->InitializeNewDomains(); - - // Update any changed domain from the context. - for (int i = 0; i < context->working_model->variables_size(); ++i) { - FillDomainInProto(context->DomainOf(i), - context->working_model->mutable_variables(i)); - } - + context->WriteVariableDomainsToProto(); context->NotifyThatModelIsExpanded(); } void FinalExpansionForLinearConstraint(PresolveContext* context) { if (context->params().disable_constraint_expansion()) return; if (context->ModelIsUnsat()) return; - for (int c = 0; c < context->working_model->constraints_size(); ++c) { - ConstraintProto* const ct = context->working_model->mutable_constraints(c); - switch (ct->constraint_case()) { + for (int c = 0; c < context->NumConstraints(); ++c) { + const ConstraintProto& ct = context->Constraint(c); + switch (ct.constraint_case()) { case ConstraintProto::kLinear: - if (ct->linear().domain().size() > 2) { - ExpandComplexLinearConstraint(c, ct, context); + if (ct.linear().domain().size() > 2) { + ExpandComplexLinearConstraint(c, context->MutableConstraint(c), + context); } break; default: diff --git a/ortools/sat/cp_model_expand_test.cc b/ortools/sat/cp_model_expand_test.cc index 46036210ebb..51b3cb5c1b1 100644 --- a/ortools/sat/cp_model_expand_test.cc +++ b/ortools/sat/cp_model_expand_test.cc @@ -14,10 +14,12 @@ #include "ortools/sat/cp_model_expand.h" #include +#include #include #include "absl/container/btree_set.h" #include "absl/log/check.h" +#include "absl/random/random.h" #include "absl/strings/string_view.h" #include "gtest/gtest.h" #include "ortools/base/container_logging.h" @@ -1967,7 +1969,6 @@ TEST(ExpandAllDiffTest, GoldenTestWithEnforcementLiteral) { Model model; PresolveContext context(&model, &initial_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); ExpandCpModel(&context); const CpModelProto expected_model = ParseTestProto(R"pb( @@ -2082,6 +2083,57 @@ TEST(ExpandInverseTest, DuplicateAtDifferentPosition) { EXPECT_EQ(4, solutions.size()); } +TEST(ExpandInverseTest, ExpressionsWithDuplicateVariables) { + absl::BitGen random; + std::vector> coeff_and_offsets; + for (int coeff = -2; coeff < 3; ++coeff) { + for (int offset = -1; offset < 2; ++offset) { + coeff_and_offsets.push_back({coeff, offset}); + } + } + for (int i = 0; i < 1000; ++i) { + CpModelProto initial_model; + for (int j = 0; j < 4; ++j) { + auto* var = initial_model.add_variables(); + var->add_domain(-1); + var->add_domain(4); + } + auto* inverse = initial_model.add_constraints()->mutable_inverse(); + // Create an inverse constraint with direct expressions using variables + // {0, 1, 2, 3} and inverse expressions using variables {1, 2, 3, 4}. + for (int j = 0; j < 4; ++j) { + int k = absl::Uniform(random, 0, (int)coeff_and_offsets.size()); + auto* f_direct = inverse->add_f_expr_direct(); + if (coeff_and_offsets[k].first != 0) { + f_direct->add_vars(j); + f_direct->add_coeffs(coeff_and_offsets[k].first); + } + f_direct->set_offset(coeff_and_offsets[k].second); + + k = absl::Uniform(random, 0, (int)coeff_and_offsets.size()); + auto* f_inverse = inverse->add_f_expr_inverse(); + if (coeff_and_offsets[k].first != 0) { + f_inverse->add_vars((j + 1) % 4); + f_inverse->add_coeffs(coeff_and_offsets[k].first); + } + f_inverse->set_offset(coeff_and_offsets[k].second); + } + absl::btree_set> solutions; + const CpSolverResponse response = + SolveAndCheck(initial_model, "", &solutions); + + int num_expected_solutions = 0; + for (int j = 0; j < 6 * 6 * 6 * 6; ++j) { + if (SolutionIsFeasible(initial_model, {j / 216 - 1, (j / 36) % 6 - 1, + (j / 6) % 6 - 1, j % 6 - 1})) { + ++num_expected_solutions; + } + } + EXPECT_EQ(solutions.size(), num_expected_solutions) + << initial_model.DebugString(); + } +} + TEST(TableExpandTest, UsedToFail) { const CpModelProto initial_model = ParseTestProto(R"pb( variables { domain: [ 0, 3 ] } @@ -2310,7 +2362,7 @@ TEST(FinalExpansionForLinearConstraintTest, ComplexLinearExpansion) { PresolveContext context(&model, &initial_model, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); FinalExpansionForLinearConstraint(&context); @@ -2381,7 +2433,7 @@ TEST(FinalExpansionForLinearConstraintTest, ComplexLinearExpansionWithInteger) { PresolveContext context(&model, &initial_model, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); FinalExpansionForLinearConstraint(&context); @@ -2434,7 +2486,7 @@ TEST(FinalExpansionForLinearConstraintTest, PresolveContext context(&model, &initial_model, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); FinalExpansionForLinearConstraint(&context); diff --git a/ortools/sat/cp_model_lns.cc b/ortools/sat/cp_model_lns.cc index 7e39be61a9a..7a12c1f6372 100644 --- a/ortools/sat/cp_model_lns.cc +++ b/ortools/sat/cp_model_lns.cc @@ -43,6 +43,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/macros/buildenv.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/graph_base/connected_components.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_copy.h" @@ -97,6 +98,7 @@ NeighborhoodGeneratorHelper::NeighborhoodGeneratorHelper( } *model_proto_with_only_variables_.mutable_variables() = model_proto_.variables(); + InitializeHelperData(); RecomputeHelperData(); Synchronize(); @@ -189,8 +191,7 @@ bool NeighborhoodGeneratorHelper::ObjectiveDomainIsConstraining() const { const Domain obj_domain = ReadDomainFromProto(objective); const Domain inferred_domain = Domain(min_activity, max_activity) - .IntersectionWith( - Domain(std::numeric_limits::min(), obj_domain.Max())); + .IntersectionWith(Domain(kint64min, obj_domain.Max())); return !inferred_domain.IsIncludedIn(obj_domain); } @@ -212,8 +213,18 @@ void NeighborhoodGeneratorHelper::RecomputeHelperData() { absl::MutexLock graph_lock(graph_mutex_); absl::ReaderMutexLock domain_lock(domain_mutex_); + std::vector mapping; if (shared_clauses_ != nullptr) { - var_to_representative_ = shared_clauses_->GetRepresentatives(); + mapping = shared_clauses_->GetRepresentatives(); + } + if (!mapping.empty()) { + // GetRepresentatives() return the smallest possible vector, complete it. + const int num_vars = model_proto_with_only_variables_.variables().size(); + CHECK_LE(mapping.size(), num_vars); + mapping.reserve(num_vars); + while (mapping.size() < num_vars) { + mapping.push_back(mapping.size()); + } } // Do basic presolving to have a more precise graph. @@ -228,7 +239,6 @@ void NeighborhoodGeneratorHelper::RecomputeHelperData() { // TODO(user): We can also start from the previous simplified model instead. { Model local_model; - CpModelProto mapping_proto; // We want to replace the simplified_model_proto_ by a new one. Since // deleting an object in the arena doesn't free the memory, we also delete // and recreate the arena, but reusing the same storage. @@ -239,44 +249,25 @@ void NeighborhoodGeneratorHelper::RecomputeHelperData() { local_arena_storage_.resize(new_size); local_arena_ = std::make_unique( local_arena_storage_.data(), local_arena_storage_.size()); + + // Note that here we don't densify the space of indices, so we don't need + // any special logic for the constants. ModelCopy will already deal with + // them efficiently. simplified_model_proto_ = google::protobuf::Arena::Create(local_arena_.get()); - *simplified_model_proto_->mutable_variables() = - model_proto_with_only_variables_.variables(); - - const int num_vars = model_proto_with_only_variables_.variables_size(); - std::vector mapping; - std::vector reverse_mapping; - mapping.reserve(num_vars); - reverse_mapping.reserve(num_vars); - int first_fixed_literal = -1; - for (int var = 0; var < num_vars; ++var) { - if (IsConstant(var)) { - const int64_t value = ConstantValue(var); - // ModelCopy requires that if some literals are fixed, then one of them - // must not be removed by the mapping. - if (first_fixed_literal == -1 && (value == 0 || value == 1)) { - first_fixed_literal = var; - mapping.push_back(var); - } else { - mapping.push_back(kNoVariableMapping); - } - } else { - mapping.push_back(GetRepresentative(var)); - } - reverse_mapping.push_back(var); - } - - PresolveContext context(&local_model, simplified_model_proto_, - &mapping_proto); - ModelCopy copier(&context, mapping, reverse_mapping); + ModelCopy copier(simplified_model_proto_, &local_model, mapping); - // TODO(user): Not sure what to do if the model is UNSAT. - // This shouldn't matter as it should be dealt with elsewhere. - copier.ImportAndSimplifyConstraints(model_proto_); - copier.ImportObjective(model_proto_); + // When the model is unsat, we abort any update. + // This shouldn't matter as it should be dealt with elsewhere. + if (!copier.ImportVariables(model_proto_with_only_variables_)) return; + if (!copier.ImportAndSimplifyConstraints(model_proto_)) return; + if (!copier.ImportObjective(model_proto_)) return; + if (!copier.FinishCopy(model_proto_)) return; } + // We copied the model successfully, so commit. + var_to_representative_ = mapping; + // Compute the constraint <-> variable graph. // // TODO(user): Remove duplicate constraints? @@ -350,6 +341,7 @@ void NeighborhoodGeneratorHelper::RecomputeHelperData() { const int var = objective.vars(i); const int64_t coeff = objective.coeffs(i); DCHECK(RefIsPositive(var)); + DCHECK(!IsConstant(var)); DCHECK_NE(coeff, 0); is_in_objective_[var] = true; has_positive_objective_coefficient_[var] = coeff > 0; @@ -770,8 +762,8 @@ void ProcessDemandListFromCumulativeConstraint( // Checks if any pairs of tasks cannot overlap. int64_t sum_of_min_two_capacities = 2; if (capacity > 1) { - int64_t min1 = std::numeric_limits::max(); - int64_t min2 = std::numeric_limits::max(); + int64_t min1 = kint64max; + int64_t min2 = kint64max; for (const Demand& demand : demands) { if (demand.height <= min1) { min2 = min1; @@ -1046,7 +1038,7 @@ NeighborhoodGeneratorHelper::GetRoutingPathBooleanVariables( const CircuitConstraintProto& ct = ModelProto().constraints(i).circuit(); // Collect arcs. - int min_node = std::numeric_limits::max(); + int min_node = kint32max; tail_to_head_and_arc_bool_var.clear(); for (int i = 0; i < ct.literals_size(); ++i) { const int literal = ct.literals(i); @@ -1122,88 +1114,63 @@ Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( const Bitset64& variables_to_fix) const { const int num_variables = variables_to_fix.size(); Neighborhood neighborhood(num_variables); - neighborhood.delta.mutable_variables()->Reserve(num_variables); - // TODO(user): Maybe relax all variables in the objective when the number - // is small or negligible compared to the number of variables. - int unique_objective_variable = -1; + // Do a bit of filtering on what to fix compared to "variables_to_fix". + // We want only active variable in there (at the time we acquire the lock). + std::vector to_fix; + std::vector in_to_fix(num_variables, false); { absl::ReaderMutexLock graph_lock(graph_mutex_); + + // If there is an unique objective variable, we will always relax it. + // + // TODO(user): Maybe relax all variables in the objective when the number + // is small or negligible compared to the number of variables. + int unique_objective_variable = -1; if (simplified_model_proto_->has_objective() && simplified_model_proto_->objective().vars_size() == 1) { unique_objective_variable = simplified_model_proto_->objective().vars(0); } - } - // Fill in neighborhood.delta all variable domains. - int num_fixed = 0; - { - absl::ReaderMutexLock domain_lock(domain_mutex_); - for (int i = 0; i < num_variables; ++i) { - const IntegerVariableProto& current_var = - model_proto_with_only_variables_.variables(i); - IntegerVariableProto* new_var = neighborhood.delta.add_variables(); - - // We only copy the name in debug mode. - if (DEBUG_MODE) new_var->set_name(current_var.name()); - - if (variables_to_fix[i] && i != unique_objective_variable) { - ++num_fixed; - - // Note the use of DomainInProtoContains() instead of - // ReadDomainFromProto() as the later is slower and allocate memory. - const int64_t base_value = base_solution.solution(i); - if (DomainInProtoContains(current_var, base_value)) { - new_var->add_domain(base_value); - new_var->add_domain(base_value); - } else { - // If under the updated domain, the base solution is no longer valid, - // We should probably regenerate this neighborhood. But for now we - // just do a best effort and take the closest value. - const Domain domain = ReadDomainFromProto(current_var); - int64_t closest_value = domain.Min(); - int64_t closest_dist = std::abs(closest_value - base_value); - for (const ClosedInterval interval : domain) { - for (const int64_t value : {interval.start, interval.end}) { - const int64_t dist = std::abs(value - base_value); - if (dist < closest_dist) { - closest_value = value; - closest_dist = dist; - } - } - } - FillDomainInProto(Domain(closest_value, closest_value), new_var); - } - } else { - *new_var->mutable_domain() = current_var.domain(); + for (int var = 0; var < num_variables; ++var) { + if (!variables_to_fix[var]) continue; + + // We only fix representative. + // This is needed for correctness of the "connected component" logic. + // + // We don't need to fix all the other, since the initial copy to generate + // the LNS fragment will use the most up to date equivalence information. + const int rep = PositiveRef(GetRepresentative(var)); + if (rep == unique_objective_variable) continue; + + // Skip non-active variable (they should be already fixed). + if (!active_variables_set_[rep]) continue; + + if (!in_to_fix[rep]) { + in_to_fix[rep] = true; + to_fix.push_back(rep); } } - } - // Fill some statistic fields and detect if we cover a full component. - // - // TODO(user): If there is just one component, we can skip some computation. - { - absl::ReaderMutexLock graph_lock(graph_mutex_); - std::vector count(components_.size(), 0); - const int num_variables = neighborhood.delta.variables().size(); - for (int var = 0; var < num_variables; ++var) { - const auto& domain = neighborhood.delta.variables(var).domain(); - if (domain.size() != 2 || domain[0] != domain[1]) { + std::vector component_was_altered(components_.size(), false); + for (const int var : active_variables_) { + DCHECK_EQ(GetRepresentative(var), var); + if (in_to_fix[var]) { + const int c = var_to_component_index_[var]; + DCHECK_NE(c, -1); + component_was_altered[c] = true; + } else { + // Anything non-fixed is considered "relaxed". + // TODO(user): shall we ignore fixed variables. ++neighborhood.num_relaxed_variables; if (is_in_objective_[var]) { ++neighborhood.num_relaxed_variables_in_objective; } - const int c = var_to_component_index_[var]; - if (c != -1) { - DCHECK_EQ(GetRepresentative(var), var); - count[c]++; - } } } for (int i = 0; i < components_.size(); ++i) { - if (count[i] == components_[i].size()) { + if (!component_was_altered[i]) { neighborhood.variables_that_can_be_fixed_to_local_optimum.insert( neighborhood.variables_that_can_be_fixed_to_local_optimum.end(), components_[i].begin(), components_[i].end()); @@ -1211,11 +1178,51 @@ Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( } } + // We start by copying the current domains. Note that if ignore_names is true + // (the default), we should already have no names here so we don't waste that + // space. + { + absl::ReaderMutexLock domain_lock(domain_mutex_); + *neighborhood.delta.mutable_variables() = + model_proto_with_only_variables_.variables(); + } + + // Do the actual fixing using the base solution. + for (const int var : to_fix) { + IntegerVariableProto* var_proto = neighborhood.delta.mutable_variables(var); + + // Note the use of DomainInProtoContains() instead of + // ReadDomainFromProto() as the later is slower and allocate memory. + int64_t fixed_value = base_solution.solution(var); + if (!DomainInProtoContains(*var_proto, fixed_value)) { + // If under the updated domain, the base solution is no longer valid, + // We should probably regenerate this neighborhood. But for now we + // just do a best effort and take the closest value. + const Domain domain = ReadDomainFromProto(*var_proto); + int64_t closest_value = domain.Min(); + int64_t closest_dist = std::abs(closest_value - fixed_value); + for (const ClosedInterval interval : domain) { + for (const int64_t value : {interval.start, interval.end}) { + const int64_t dist = std::abs(value - fixed_value); + if (dist < closest_dist) { + closest_value = value; + closest_dist = dist; + } + } + } + fixed_value = closest_value; + } + + // Fix that variable domain. + FillDomainInProto(Domain(fixed_value, fixed_value), var_proto); + } + // If the objective domain might cut the optimal solution, we cannot exploit - // the connected components. We compute this outside the mutex to avoid - // any deadlock risk. + // the connected components. // // TODO(user): We could handle some complex domain (size > 2). + // TODO(user): We could still handle component that do not contain objective + // terms. if (model_proto_.has_objective() && (model_proto_.objective().domain().size() != 2 || shared_response_->GetInnerObjectiveLowerBound() < @@ -1223,6 +1230,7 @@ Neighborhood NeighborhoodGeneratorHelper::FixGivenVariables( neighborhood.variables_that_can_be_fixed_to_local_optimum.clear(); } + const int num_fixed = to_fix.size(); const int num_relaxed = num_variables - num_fixed; neighborhood.delta.mutable_solution_hint()->mutable_vars()->Reserve( num_relaxed); @@ -1706,7 +1714,7 @@ Neighborhood DecompositionGraphNeighborhoodGenerator::Generate( absl::BitGenRef random) { int max_width = 0; int size_at_min_width_after_100; - int min_width_after_100 = std::numeric_limits::max(); + int min_width_after_100 = kint32max; int num_zero_score = 0; std::vector relaxed_variables; @@ -1888,7 +1896,7 @@ ConstraintProto DistanceToBoundsSmallerThanConstraint( linear->add_coeffs(-1); linear->add_vars(var); } - linear->add_domain(std::numeric_limits::min()); + linear->add_domain(kint64min); linear->add_domain(rhs - lhs_constant_value); return new_constraint; } @@ -2142,7 +2150,7 @@ namespace { void AddPrecedence(const LinearExpressionProto& before, const LinearExpressionProto& after, CpModelProto* model) { LinearConstraintProto* linear = model->add_constraints()->mutable_linear(); - linear->add_domain(std::numeric_limits::min()); + linear->add_domain(kint64min); linear->add_domain(after.offset() - before.offset()); for (int i = 0; i < before.vars_size(); ++i) { linear->add_vars(before.vars(i)); diff --git a/ortools/sat/cp_model_lns.h b/ortools/sat/cp_model_lns.h index 85349bc65dc..01b42b8503b 100644 --- a/ortools/sat/cp_model_lns.h +++ b/ortools/sat/cp_model_lns.h @@ -124,6 +124,9 @@ class NeighborhoodGeneratorHelper : public SubSolver { // Returns the LNS fragment where the given variables are fixed to the value // they take in the given solution. + // + // Note that we will first transform 'variables_to_fix' so that it contains + // only active variables. Neighborhood FixGivenVariables(const CpSolverResponse& base_solution, const Bitset64& variables_to_fix) const; diff --git a/ortools/sat/cp_model_lns_test.cc b/ortools/sat/cp_model_lns_test.cc index 789044131ec..e23735ed71c 100644 --- a/ortools/sat/cp_model_lns_test.cc +++ b/ortools/sat/cp_model_lns_test.cc @@ -621,19 +621,12 @@ TEST(NeighborhoodGeneratorHelperTest, FixGivenVariables) { Bitset64 variables_to_fix(helper.NumVariables()); for (const int var : {2, 0}) variables_to_fix.Set(var); const Neighborhood n = helper.FixGivenVariables(response, variables_to_fix); - const CpModelProto expected_output = - DEBUG_MODE ? ParseTestProto(R"pb( - variables { name: "x" domain: 2 domain: 2 } - variables { name: "y" domain: 0 domain: 10 } - variables { name: "z" domain: 4 domain: 4 } - solution_hint { vars: 1 values: 3 } - )pb") - : ParseTestProto(R"pb( - variables { domain: 2 domain: 2 } - variables { domain: 0 domain: 10 } - variables { domain: 4 domain: 4 } - solution_hint { vars: 1 values: 3 } - )pb"); + const CpModelProto expected_output = ParseTestProto(R"pb( + variables { name: "x" domain: 2 domain: 2 } + variables { name: "y" domain: 0 domain: 10 } + variables { name: "z" domain: 4 domain: 4 } + solution_hint { vars: 1 values: 3 } + )pb"); EXPECT_THAT(n.delta, testing::EqualsProto(expected_output)); } diff --git a/ortools/sat/cp_model_loader.cc b/ortools/sat/cp_model_loader.cc index cc7bacdb6ed..88716e550b5 100644 --- a/ortools/sat/cp_model_loader.cc +++ b/ortools/sat/cp_model_loader.cc @@ -104,6 +104,9 @@ void LoadVariables(const CpModelProto& model_proto, auto* mapping = m->GetOrCreate(); const int num_proto_variables = model_proto.variables_size(); + // We assume this is alive during the whole solve. + mapping->model_proto_ = &model_proto; + // All [0, 1] variables always have a corresponding Boolean, even if it is // fixed to 0 (domain == [0,0]) or fixed to 1 (domain == [1,1]). { @@ -188,7 +191,7 @@ void LoadVariables(const CpModelProto& model_proto, IndexReferences refs; for (int c = 0; c < model_proto.constraints_size(); ++c) { const ConstraintProto& ct = model_proto.constraints(c); - refs = GetReferencesUsedByConstraint(ct); + GetReferencesUsedByConstraint(ct, &refs.variables, &refs.literals); for (const int ref : refs.variables) { used_variables.insert(PositiveRef(ref)); } @@ -1012,6 +1015,7 @@ void LoadBoolOrConstraint(const ConstraintProto& ct, Model* m) { void LoadBoolAndConstraint(const ConstraintProto& ct, Model* m) { auto* mapping = m->GetOrCreate(); std::vector literals; + literals.reserve(ct.enforcement_literal().size() + 1); for (const int ref : ct.enforcement_literal()) { literals.push_back(mapping->Literal(ref).Negated()); } @@ -1278,6 +1282,8 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { } } + // TODO(user): We should probably reuse memory to avoid allocating for large + // problem with many small constraints. auto* integer_trail = m->GetOrCreate(); std::vector vars = mapping->Integers(ct.linear().vars()); std::vector coeffs(ct.linear().coeffs().begin(), @@ -1411,12 +1417,15 @@ void LoadLinearConstraint(const ConstraintProto& ct, Model* m) { /*use_lower_bound=*/(min_sum < lb), lb, /*use_upper_bound=*/(max_sum > ub), ub, &enforcement_literals, &cst); } else { - if (min_sum < lb) { - AddWeightedSumGreaterOrEqual(enforcement_literals, vars, coeffs, lb, m); - } if (max_sum > ub) { AddWeightedSumLowerOrEqual(enforcement_literals, vars, coeffs, ub, m); } + if (min_sum < lb) { + // AddWeightedSumGreaterOrEqual() allocates a vector, so we do the + // conversion here since we already allocated one. + for (IntegerValue& ref : coeffs) ref = -ref; + AddWeightedSumLowerOrEqual(enforcement_literals, vars, coeffs, -lb, m); + } } return; } diff --git a/ortools/sat/cp_model_mapping.h b/ortools/sat/cp_model_mapping.h index 18af79df182..7d3e21f6957 100644 --- a/ortools/sat/cp_model_mapping.h +++ b/ortools/sat/cp_model_mapping.h @@ -233,6 +233,9 @@ class CpModelMapping { // Returns the number of variables in the loaded proto. int NumProtoVariables() const { return integers_.size(); } + // The current cp_model proto. + CpModelProto const* ModelProto() const { return model_proto_; } + private: friend void LoadVariables(const CpModelProto& model_proto, bool view_all_booleans_as_integers, Model* m); @@ -259,6 +262,8 @@ class CpModelMapping { // differently during linearization. absl::flat_hash_set encoding_ct_; + CpModelProto const* model_proto_ = nullptr; + int64_t num_non_boolean_integers_ = 0; }; diff --git a/ortools/sat/cp_model_postsolve.cc b/ortools/sat/cp_model_postsolve.cc index 81e9058313c..74f014d8c43 100644 --- a/ortools/sat/cp_model_postsolve.cc +++ b/ortools/sat/cp_model_postsolve.cc @@ -22,6 +22,7 @@ #include "absl/log/log.h" #include "absl/types/span.h" #include "ortools/base/log_severity.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" @@ -269,7 +270,7 @@ bool LinearExpressionIsFixed(const LinearExpressionProto& expr, // support post-solving the case where whatever the value of all expression, // there will be a valid target. void PostsolveLinMax(const ConstraintProto& ct, std::vector* domains) { - int64_t max_value = std::numeric_limits::min(); + int64_t max_value = kint64min; for (const LinearExpressionProto& expr : ct.lin_max().exprs()) { // In most case all expression are fixed, except in the corner case where // one of the expression refer to the target itself ! diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index 95a6846c4cf..670423645f8 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -54,6 +54,7 @@ #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/graph_base/topologicalsorter.h" #include "ortools/port/proto_utils.h" @@ -83,6 +84,7 @@ #include "ortools/sat/sat_inprocessing.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/scheduling_model.h" #include "ortools/sat/simplification.h" #include "ortools/sat/solution_crush.h" #include "ortools/sat/util.h" @@ -137,18 +139,20 @@ bool CpModelPresolver::RemoveConstraint(ConstraintProto* ct) { // Now that they have served their purpose, we also remove dummy constraints, // otherwise that causes issue because our model are invalid in tests. void CpModelPresolver::RemoveEmptyConstraints() { + // This should only be done at the end of presolve, so it is fine to break + // context_ invariants. + CpModelProto* cp_model = context_->UnsafeMutableWorkingModel(); + + const int old_num_non_empty_constraints = context_->NumConstraints(); interval_representative_.clear(); - std::vector interval_mapping(context_->working_model->constraints_size(), - -1); + std::vector interval_mapping(old_num_non_empty_constraints, -1); int new_num_constraints = 0; - const int old_num_non_empty_constraints = - context_->working_model->constraints_size(); for (int c = 0; c < old_num_non_empty_constraints; ++c) { - const auto type = context_->working_model->constraints(c).constraint_case(); + const auto type = context_->Constraint(c).constraint_case(); if (type == ConstraintProto::CONSTRAINT_NOT_SET) continue; if (type == ConstraintProto::kDummyConstraint) continue; - context_->working_model->mutable_constraints(new_num_constraints) - ->Swap(context_->working_model->mutable_constraints(c)); + cp_model->mutable_constraints(new_num_constraints) + ->Swap(cp_model->mutable_constraints(c)); if (type == ConstraintProto::kInterval) { // Warning: interval_representative_ holds a pointer to the working model // to compute hashes, so we need to be careful about not changing a @@ -156,25 +160,26 @@ void CpModelPresolver::RemoveEmptyConstraints() { const auto [it, inserted] = interval_representative_.insert( {new_num_constraints, new_num_constraints}); interval_mapping[c] = it->second; - if (it->second != new_num_constraints) { + if (!inserted) { context_->UpdateRuleStats( "intervals: change duplicate index across constraints"); continue; } } - new_num_constraints++; - } - google::protobuf::util::Truncate( - context_->working_model->mutable_constraints(), new_num_constraints); - for (ConstraintProto& ct_ref : - *context_->working_model->mutable_constraints()) { + + // After first copy, interval should always be defined before they are + // used, so we can remap as we re-index. ApplyToAllIntervalIndices( [&interval_mapping](int* ref) { *ref = interval_mapping[*ref]; CHECK_NE(-1, *ref); }, - &ct_ref); + cp_model->mutable_constraints(new_num_constraints)); + + new_num_constraints++; } + google::protobuf::util::Truncate(cp_model->mutable_constraints(), + new_num_constraints); } bool CpModelPresolver::PresolveEnforcementLiteral(ConstraintProto* ct, @@ -257,7 +262,7 @@ bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { int new_size = 0; bool changed = false; int num_true_literals = 0; - int true_literal = std::numeric_limits::min(); + int true_literal = kint32min; for (const int literal : ct->bool_xor().literals()) { // TODO(user): More generally, if a variable appear in only bool xor // constraints, we can simply eliminate it using linear algebra on Z/2Z. @@ -333,14 +338,13 @@ bool CpModelPresolver::PresolveBoolXor(ConstraintProto* ct) { return false; } } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("bool_xor: two active literals"); return RemoveConstraint(ct); } // TODO(user): maybe replace the enforced XOR by an enforced equality? } if (num_true_literals % 2 == 1) { - CHECK_NE(true_literal, std::numeric_limits::min()); + CHECK_NE(true_literal, kint32min); ct->mutable_bool_xor()->set_literals(new_size++, true_literal); } if (num_true_literals > 1) { @@ -560,7 +564,8 @@ bool CpModelPresolver::PresolveBoolAnd(ConstraintProto* ct) { return changed; } -bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) { +bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct, + bool use_dual_reduction) { bool is_at_most_one = ct->constraint_case() == ConstraintProto::kAtMostOne; const std::string name = is_at_most_one ? "at_most_one: " : "exactly_one: "; auto* literals = is_at_most_one @@ -631,23 +636,25 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) { } // A singleton variable with or without cost can be removed. See below. - if (context_->VariableIsUniqueAndRemovable(literal)) { - // A variable that doesn't appear in the objective can be seen as - // appearing with a coefficient of zero. - singleton_literal_with_cost.push_back({literal, 0}); - continue; - } - if (context_->VariableWithCostIsUniqueAndRemovable(literal)) { - const auto it = context_->ObjectiveMap().find(PositiveRef(literal)); - DCHECK(it != context_->ObjectiveMap().end()); - if (RefIsPositive(literal)) { - singleton_literal_with_cost.push_back({literal, it->second}); - } else { - // Note that we actually just store the objective change if this literal - // is true compared to it being false. - singleton_literal_with_cost.push_back({literal, -it->second}); + if (use_dual_reduction) { + if (context_->VariableIsUniqueAndRemovable(literal)) { + // A variable that doesn't appear in the objective can be seen as + // appearing with a coefficient of zero. + singleton_literal_with_cost.push_back({literal, 0}); + continue; + } + if (context_->VariableWithCostIsUniqueAndRemovable(literal)) { + const auto it = context_->ObjectiveMap().find(PositiveRef(literal)); + DCHECK(it != context_->ObjectiveMap().end()); + if (RefIsPositive(literal)) { + singleton_literal_with_cost.push_back({literal, it->second}); + } else { + // Note that we actually just store the objective change if this + // literal is true compared to it being false. + singleton_literal_with_cost.push_back({literal, -it->second}); + } + continue; } - continue; } context_->tmp_literals.push_back(literal); @@ -727,11 +734,12 @@ bool CpModelPresolver::PresolveAtMostOrExactlyOne(ConstraintProto* ct) { return changed; } -bool CpModelPresolver::PresolveAtMostOne(ConstraintProto* ct) { +bool CpModelPresolver::PresolveAtMostOne(ConstraintProto* ct, + bool use_dual_reduction) { if (context_->ModelIsUnsat()) return false; CHECK(!HasEnforcementLiteral(*ct)); - const bool changed = PresolveAtMostOrExactlyOne(ct); + const bool changed = PresolveAtMostOrExactlyOne(ct, use_dual_reduction); if (ct->constraint_case() != ConstraintProto::kAtMostOne) return changed; // Size zero: ok. @@ -753,7 +761,8 @@ bool CpModelPresolver::PresolveAtMostOne(ConstraintProto* ct) { bool CpModelPresolver::PresolveExactlyOne(ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; CHECK(!HasEnforcementLiteral(*ct)); - const bool changed = PresolveAtMostOrExactlyOne(ct); + const bool changed = + PresolveAtMostOrExactlyOne(ct, /*use_dual_reduction=*/true); if (ct->constraint_case() != ConstraintProto::kExactlyOne) return changed; // Size zero: UNSAT. @@ -887,7 +896,7 @@ bool CpModelPresolver::PropagateAndReduceAffineMax(ConstraintProto* ct) { CHECK_LE(context_->DomainOf(unique_var).Size(), 1000); for (const int64_t value : context_->DomainOf(unique_var).Values()) { - int64_t current_max = std::numeric_limits::min(); + int64_t current_max = kint64min; // Fill tmp_values and compute current_max; for (int i = 0; i < num_exprs; ++i) { @@ -1004,7 +1013,7 @@ bool CpModelPresolver::PropagateAndReduceLinMax(ConstraintProto* ct) { // Update target domain (if it is not a complex expression). { int64_t infered_min = context_->MinOf(target); - int64_t infered_max = std::numeric_limits::min(); + int64_t infered_max = kint64min; for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { infered_min = std::max(infered_min, context_->MinOf(expr)); infered_max = std::max(infered_max, context_->MaxOf(expr)); @@ -1042,7 +1051,7 @@ bool CpModelPresolver::PropagateAndReduceLinMax(ConstraintProto* ct) { // Note that we must keep an expression >= target_min though, for corner // case like [2,3] = max([2], [0][3]); bool has_greater_or_equal_to_target_min = false; - int64_t max_at_index_to_keep = std::numeric_limits::min(); + int64_t max_at_index_to_keep = kint64min; int index_to_keep = -1; for (int i = 0; i < ct->lin_max().exprs_size(); ++i) { const LinearExpressionProto& expr = ct->lin_max().exprs(i); @@ -1081,6 +1090,27 @@ bool CpModelPresolver::PropagateAndReduceLinMax(ConstraintProto* ct) { return changed; } +void CpModelPresolver::AddLinear2ToModel(const LinearExpression2& linear2, + int64_t lb, int64_t ub) { + auto* ct = context_->AddConstraint(); + auto* linear = ct->mutable_linear(); + linear->add_domain(lb); + linear->add_domain(ub); + const absl::Span vars = linear2.non_zero_vars(); + const absl::Span coeffs = linear2.non_zero_coeffs(); + linear->mutable_vars()->Reserve(vars.size()); + linear->mutable_coeffs()->Reserve(vars.size()); + for (int i = 0; i < vars.size(); ++i) { + linear->add_vars(GetPositiveOnlyIndex(vars[i]).value()); + linear->add_coeffs(VariableIsPositive(vars[i]) ? coeffs[i].value() + : -coeffs[i].value()); + } + bool changed = false; + (void)CanonicalizeLinear(ct, &changed); + context_->UpdateRuleStats("linear2: added to model"); + known_model_linear2_.Add(linear2, lb, ub); +} + bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; // TODO(user): add support for this case. @@ -1093,9 +1123,9 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { for (const LinearExpressionProto& e : ct->lin_max().exprs()) { if (LinearExpressionProtosAreEqual(e, target)) continue; LinearConstraintProto* prec = - context_->working_model->add_constraints()->mutable_linear(); + context_->AddConstraint()->mutable_linear(); prec->add_domain(0); - prec->add_domain(std::numeric_limits::max()); + prec->add_domain(kint64max); AddLinearExpressionToLinearConstraint(target, 1, prec); AddLinearExpressionToLinearConstraint(e, -1, prec); } @@ -1160,6 +1190,16 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { ct->lin_max().exprs(i).offset()); const RelationStatus status = known_linear2_.GetStatus(expr2, lb, ub); if (status == RelationStatus::IS_TRUE) { + if (known_model_linear2_.GetStatus(expr2, lb, ub) != + RelationStatus::IS_TRUE) { + // Subtle: the linear2 might have been indirectly deduced using this + // lin_max constraint. The linear2 could be encoded as a boolean + // when loading the model on probing and then this boolean could + // have been assigned using clauses that are coming from this + // lin_max during propagation. But since we know that those bounds + // are always true, adding it to the model is safe. + AddLinear2ToModel(expr2, lb.value(), ub.value()); + } simplified = true; can_be_removed[i] = true; break; @@ -1186,7 +1226,7 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { // new constraint otherwise it might not be processed again. if (ct->lin_max().exprs().size() == 1) { context_->UpdateRuleStats("lin_max: converted to equality"); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct = *ct; // copy name and potential reification. auto* arg = new_ct->mutable_linear(); const LinearExpressionProto& a = ct->lin_max().target(); @@ -1201,7 +1241,6 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { } arg->add_domain(b.offset() - a.offset()); arg->add_domain(b.offset() - a.offset()); - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -1236,8 +1275,8 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { // Checks if the affine target domain is constraining. bool linear_target_domain_contains_max_domain = false; if (ExpressionContainsSingleRef(target)) { // target = +/- var. - int64_t infered_min = std::numeric_limits::min(); - int64_t infered_max = std::numeric_limits::min(); + int64_t infered_min = kint64min; + int64_t infered_max = kint64min; for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { infered_min = std::max(infered_min, context_->MinOf(expr)); infered_max = std::max(infered_max, context_->MaxOf(expr)); @@ -1311,10 +1350,9 @@ bool CpModelPresolver::PresolveLinMax(int c, ConstraintProto* ct) { (context_->ObjectiveCoeff(target.vars(0)) > 0)) { context_->UpdateRuleStats("lin_max: rewrite with precedences"); for (const LinearExpressionProto& expr : ct->lin_max().exprs()) { - LinearConstraintProto* prec = - context_->working_model->add_constraints()->mutable_linear(); + LinearConstraintProto* prec = context_->AddConstraint()->mutable_linear(); prec->add_domain(0); - prec->add_domain(std::numeric_limits::max()); + prec->add_domain(kint64max); AddLinearExpressionToLinearConstraint(target, 1, prec); AddLinearExpressionToLinearConstraint(expr, -1, prec); } @@ -1424,7 +1462,7 @@ bool CpModelPresolver::PresolveLinMaxWhenAllBoolean(ConstraintProto* ct) { context_->UpdateRuleStats("lin_max: all booleans"); // target_ref => at_least_one(max_literals); - ConstraintProto* clause = context_->working_model->add_constraints(); + ConstraintProto* clause = context_->AddConstraint(); clause->add_enforcement_literal(target_ref); clause->mutable_bool_or(); for (const int lit : max_literals) { @@ -1438,7 +1476,7 @@ bool CpModelPresolver::PresolveLinMaxWhenAllBoolean(ConstraintProto* ct) { if (!min_is_reachable) { // not(target_ref) => at_least_one(min_literals). - ConstraintProto* clause = context_->working_model->add_constraints(); + ConstraintProto* clause = context_->AddConstraint(); clause->add_enforcement_literal(NegatedRef(target_ref)); clause->mutable_bool_or(); for (const int lit : min_literals) { @@ -1446,7 +1484,6 @@ bool CpModelPresolver::PresolveLinMaxWhenAllBoolean(ConstraintProto* ct) { } } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -1464,7 +1501,7 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { const Domain expr_domain = context_->DomainSuperSetOf(expr); const Domain new_target_domain = expr_domain.UnionWith(expr_domain.Negation()) - .IntersectionWith({0, std::numeric_limits::max()}); + .IntersectionWith({0, kint64max}); bool target_domain_modified = false; if (!context_->IntersectDomainWith(target_expr, new_target_domain, &target_domain_modified)) { @@ -1481,9 +1518,8 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { // Propagate from target domain to variable. { - const Domain target_domain = - context_->DomainSuperSetOf(target_expr) - .IntersectionWith(Domain(0, std::numeric_limits::max())); + const Domain target_domain = context_->DomainSuperSetOf(target_expr) + .IntersectionWith(Domain(0, kint64max)); const Domain new_expr_domain = target_domain.UnionWith(target_domain.Negation()); bool expr_domain_modified = false; @@ -1505,7 +1541,7 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { // Convert to equality if the sign of expr is fixed. if (context_->MinOf(expr) >= 0) { context_->UpdateRuleStats("lin_max: converted abs to equality"); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->set_name(ct->name()); auto* arg = new_ct->mutable_linear(); arg->add_domain(0); @@ -1516,13 +1552,12 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { if (!CanonicalizeLinear(new_ct, &changed)) { return true; } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } if (context_->MaxOf(expr) <= 0) { context_->UpdateRuleStats("lin_max: converted abs to equality"); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->set_name(ct->name()); auto* arg = new_ct->mutable_linear(); arg->add_domain(0); @@ -1533,7 +1568,6 @@ bool CpModelPresolver::PropagateAndReduceIntAbs(ConstraintProto* ct) { if (!CanonicalizeLinear(new_ct, &changed)) { return true; } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -1655,13 +1689,12 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { context_->UpdateRuleStats("int_prod: constant product"); } else { // Replace ct with an enforced linear "target == constant_factor". - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); LinearConstraintProto* const lin = new_ct->mutable_linear(); lin->add_domain(constant_factor); lin->add_domain(constant_factor); AddLinearExpressionToLinearConstraint(ct->int_prod().target(), 1, lin); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("enforced int_prod: constant product"); } return RemoveConstraint(ct); @@ -1687,29 +1720,25 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // Replace with linear if it cannot overflow. if (ct->int_prod().exprs().size() == 1) { - LinearExpressionProto* const target = - ct->mutable_int_prod()->mutable_target(); - ConstraintProto* const new_ct = context_->working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); - LinearConstraintProto* const lin = new_ct->mutable_linear(); - - if (context_->IsFixed(*target)) { - int64_t target_value = context_->FixedValue(*target); + if (context_->IsFixed(ct->int_prod().target())) { + const int64_t target_value = + context_->FixedValue(ct->int_prod().target()); if (target_value % constant_factor != 0) { - context_->working_model->mutable_constraints()->RemoveLast(); return MarkConstraintAsFalse( ct, "int_prod: product incompatible with fixed target"); } + // expression == target_value / constant_factor. + ConstraintProto* new_ct = context_->AddEnforcedConstraint(ct); + LinearConstraintProto* const lin = new_ct->mutable_linear(); lin->add_domain(target_value / constant_factor); lin->add_domain(target_value / constant_factor); AddLinearExpressionToLinearConstraint(ct->int_prod().exprs(0), 1, lin); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("int_prod: expression is constant"); return RemoveConstraint(ct); } - const int64_t target_divisor = LinearExpressionGcd(*target); + const int64_t target_divisor = LinearExpressionGcd(ct->int_prod().target()); // Reduce coefficients. const int64_t gcd = @@ -1717,10 +1746,12 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { static_cast(std::abs(target_divisor))); if (gcd != 1) { constant_factor /= gcd; - DivideLinearExpression(gcd, target); + DivideLinearExpression(gcd, ct->mutable_int_prod()->mutable_target()); } // expression * constant_factor = target. + ConstraintProto* new_ct = context_->AddEnforcedConstraint(ct); + LinearConstraintProto* const lin = new_ct->mutable_linear(); lin->add_domain(0); lin->add_domain(0); const bool overflow = !SafeAddLinearExpressionToLinearConstraint( @@ -1730,13 +1761,12 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // Check for overflow. if (overflow || - PossibleIntegerOverflow(*context_->working_model, lin->vars(), + PossibleIntegerOverflow(context_->WorkingModel(), lin->vars(), lin->coeffs(), lin->domain(0))) { - context_->working_model->mutable_constraints()->RemoveLast(); // The constant factor will be handled by the creation of an affine // relation below. + context_->RemoveLastConstraint(); } else { // Replace with a linear equation. - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("int_prod: linearize product by constant"); return RemoveConstraint(ct); } @@ -1758,7 +1788,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { if (HasEnforcementLiteral(*ct) || CapProd(constant_factor, std::max(context_->MaxOf(old_target), -context_->MinOf(old_target))) >= - std::numeric_limits::max() / 2) { + kint64max / 2) { // Restore the original constraint (we cannot add back a new term for // the constant factor: this may create a constraint with more than 2 // terms). @@ -1813,10 +1843,10 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // exactly sure we are correct here since it depends on the evaluation // order. Similarly, by introducing intermediate variable we might loose // solution if this intermediate variable value do not fit on an int64_t. - if (new_coeff > absl::int128(std::numeric_limits::max()) || - new_coeff < absl::int128(std::numeric_limits::min()) || - new_offset > absl::int128(std::numeric_limits::max()) || - new_offset < absl::int128(std::numeric_limits::min())) { + if (new_coeff > absl::int128(kint64max) || + new_coeff < absl::int128(kint64min) || + new_offset > absl::int128(kint64max) || + new_offset < absl::int128(kint64min)) { return MarkConstraintAsFalse( ct, "int_prod: overflow during simplification"); } @@ -1904,22 +1934,14 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // We can rewrite this as: // u * other_expr = B, if b = false; // (u + v) * other_expr = B, if b = true - ConstraintProto* constraint_for_false = - context_->working_model->add_constraints(); - ConstraintProto* constraint_for_true = - context_->working_model->add_constraints(); - *constraint_for_true->mutable_enforcement_literal() = - ct->enforcement_literal(); - *constraint_for_false->mutable_enforcement_literal() = - ct->enforcement_literal(); - constraint_for_true->add_enforcement_literal(boolean_linear->vars(0)); - constraint_for_false->add_enforcement_literal( - NegatedRef(boolean_linear->vars(0))); - LinearConstraintProto* linear_for_false = - constraint_for_false->mutable_linear(); - LinearConstraintProto* linear_for_true = - constraint_for_true->mutable_linear(); + ConstraintProto* ct_for_false = context_->AddEnforcedConstraint(ct); + ConstraintProto* ct_for_true = context_->AddEnforcedConstraint(ct); + + const int selector = boolean_linear->vars(0); + ct_for_true->add_enforcement_literal(selector); + ct_for_false->add_enforcement_literal(NegatedRef(selector)); + LinearConstraintProto* linear_for_false = ct_for_false->mutable_linear(); linear_for_false->add_domain(0); linear_for_false->add_domain(0); AddLinearExpressionToLinearConstraint( @@ -1927,6 +1949,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { AddLinearExpressionToLinearConstraint(ct->int_prod().target(), -1, linear_for_false); + LinearConstraintProto* linear_for_true = ct_for_true->mutable_linear(); linear_for_true->add_domain(0); linear_for_true->add_domain(0); AddLinearExpressionToLinearConstraint( @@ -1934,19 +1957,21 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { linear_for_true); AddLinearExpressionToLinearConstraint(ct->int_prod().target(), -1, linear_for_true); - context_->CanonicalizeLinearConstraint(constraint_for_false); - context_->CanonicalizeLinearConstraint(constraint_for_true); - if (PossibleIntegerOverflow(*context_->working_model, + + context_->CanonicalizeLinearConstraint(ct_for_false); + context_->CanonicalizeLinearConstraint(ct_for_true); + + if (PossibleIntegerOverflow(context_->WorkingModel(), linear_for_false->vars(), linear_for_false->coeffs()) || - PossibleIntegerOverflow(*context_->working_model, + PossibleIntegerOverflow(context_->WorkingModel(), linear_for_true->vars(), linear_for_true->coeffs())) { - context_->working_model->mutable_constraints()->RemoveLast(); - context_->working_model->mutable_constraints()->RemoveLast(); + context_->RemoveLastConstraint(); + context_->RemoveLastConstraint(); + context_->UpdateRuleStats("TODO int_prod: boolean affine term"); } else { context_->UpdateRuleStats("int_prod: boolean affine term"); - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } } @@ -1970,7 +1995,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { // This is a Boolean constraint! context_->UpdateRuleStats("int_prod: all boolean"); { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->add_enforcement_literal(target); auto* arg = new_ct->mutable_bool_and(); @@ -1979,7 +2004,7 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { } } { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); auto* arg = new_ct->mutable_bool_or(); arg->add_literals(target); @@ -1987,7 +2012,6 @@ bool CpModelPresolver::PresolveIntProd(ConstraintProto* ct) { arg->add_literals(NegatedRef(lit)); } } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -2050,12 +2074,11 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { // Trivial case one: target = expr / +/-1. if (divisor == 1 || divisor == -1) { LinearConstraintProto* const lin = - context_->working_model->add_constraints()->mutable_linear(); + context_->AddConstraint()->mutable_linear(); lin->add_domain(0); lin->add_domain(0); AddLinearExpressionToLinearConstraint(expr, 1, lin); AddLinearExpressionToLinearConstraint(target, -divisor, lin); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("int_div: rewrite to equality"); return RemoveConstraint(ct); } @@ -2084,9 +2107,8 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { // Trivial case three: fixed_target = expr / fixed_divisor. if (context_->IsFixed(target) && - CapAdd(1, CapProd(std::abs(divisor), - 1 + std::abs(context_->FixedValue(target)))) != - std::numeric_limits::max()) { + CapAdd(1, CapProd(std::abs(divisor), 1 + std::abs(context_->FixedValue( + target)))) != kint64max) { int64_t t = context_->FixedValue(target); int64_t d = divisor; if (d < 0) { @@ -2115,16 +2137,13 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { // TODO(user): Deal with other cases where there is no change of // sign. We can also deal with target = cte, div variable. if (context_->MinOf(target) >= 0 && context_->MinOf(expr) >= 0 && - divisor > 1 && - CapProd(divisor, context_->MaxOf(target)) != - std::numeric_limits::max()) { + divisor > 1 && CapProd(divisor, context_->MaxOf(target)) != kint64max) { LinearConstraintProto* const lin = - context_->working_model->add_constraints()->mutable_linear(); + context_->AddConstraint()->mutable_linear(); lin->add_domain(0); lin->add_domain(divisor - 1); AddLinearExpressionToLinearConstraint(expr, 1, lin); AddLinearExpressionToLinearConstraint(target, -divisor, lin); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats( "int_div: linearize positive division with a constant divisor"); @@ -2138,19 +2157,30 @@ bool CpModelPresolver::PresolveIntDiv(int c, ConstraintProto* ct) { bool CpModelPresolver::PresolveIntMod(int c, ConstraintProto* ct) { if (context_->ModelIsUnsat()) return false; - // TODO(user): add support for this case. - if (HasEnforcementLiteral(*ct)) return false; // TODO(user): Presolve f(X) = g(X) % fixed_mod. const LinearExpressionProto target = ct->int_mod().target(); const LinearExpressionProto expr = ct->int_mod().exprs(0); const LinearExpressionProto mod = ct->int_mod().exprs(1); + if (context_->IsFixed(target) && context_->IsFixed(mod) && + context_->IsFixed(expr)) { + if (context_->FixedValue(expr) % context_->FixedValue(mod) == + context_->FixedValue(target)) { + context_->UpdateRuleStats("int_mod: fixed, always true"); + return RemoveConstraint(ct); + } else { + return MarkConstraintAsFalse(ct, "int_mod: fixed, always false"); + } + } + + // TODO(user): add support for this case. + if (HasEnforcementLiteral(*ct)) return false; + if (context_->MinOf(target) > 0) { bool domain_changed = false; - if (!context_->IntersectDomainWith( - expr, Domain(0, std::numeric_limits::max()), - &domain_changed)) { + if (!context_->IntersectDomainWith(expr, Domain(0, kint64max), + &domain_changed)) { return false; } if (domain_changed) { @@ -2167,9 +2197,8 @@ bool CpModelPresolver::PresolveIntMod(int c, ConstraintProto* ct) { if (context_->MaxOf(target) < 0) { bool domain_changed = false; - if (!context_->IntersectDomainWith( - expr, Domain(std::numeric_limits::min(), 0), - &domain_changed)) { + if (!context_->IntersectDomainWith(expr, Domain(kint64min, 0), + &domain_changed)) { return false; } if (domain_changed) { @@ -2463,10 +2492,9 @@ bool CpModelPresolver::RemoveSingletonInLinear(ConstraintProto* ct) { } else { indicator = context_->NewBoolVarWithConjunction(ct->enforcement_literal()); - auto* new_ct = context_->working_model->add_constraints(); + auto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->mutable_bool_or()->add_literals(indicator); - context_->UpdateNewConstraintsVariableUsage(); } for (int i = 0; i < num_vars; ++i) { const int64_t best_value = @@ -2852,7 +2880,6 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { if (lit == encoding_lit) return changed; context_->AddImplication(lit, encoding_lit); - context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); context_->UpdateRuleStats("linear1: transformed to implication"); return true; @@ -2862,7 +2889,6 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { // constraints linked to this variable. context_->modified_domains.Set(var); } - context_->UpdateNewConstraintsVariableUsage(); } return changed; } @@ -2874,7 +2900,6 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { if (NegatedRef(lit) == encoding_lit) return changed; context_->AddImplication(lit, NegatedRef(encoding_lit)); - context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); context_->UpdateRuleStats("linear1: transformed to implication"); return true; @@ -2884,7 +2909,6 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { // constraints linked to this variable. context_->modified_domains.Set(var); } - context_->UpdateNewConstraintsVariableUsage(); } return changed; } @@ -2901,6 +2925,7 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { const int var2 = arg.vars(1); const int64_t coeff1 = arg.coeffs(0); const int64_t coeff2 = arg.coeffs(1); + bool changed = false; // Starts by updating our hash map of known relation. { @@ -2914,20 +2939,31 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { // Note that we don't track what constraint implied the relation, so we // cannot remove this constraint even if the relation is already known. // - // However since we only add it if the relation is not - // enforced, this should be correct. + // Even if the relation is enforced, some propagator might have detected + // that the enforcement literal can only be false when the relationship + // is satisfied, thus detecting the linear2 during probing. What we can + // say here with certainty is that the constraint is always satisfied, + // even when the enforcement literal is false, so we can remove the + // enforcement. // // Tricky: If the constraint domain is not simple, we cannot really deduce // anything. if (!ct->enforcement_literal().empty() && ct->linear().domain().size() == 2) { context_->UpdateRuleStats("linear2: already known enforced relation"); - return RemoveConstraint(ct); + ct->clear_enforcement_literal(); + const auto [known_lb, known_ub] = known_linear2_.GetBounds(expr2); + DCHECK_GE(known_lb, lb); // Guaranteed by GetStatus. + DCHECK_LE(known_ub, ub); // Guaranteed by GetStatus. + ct->mutable_linear()->set_domain(0, known_lb.value()); + ct->mutable_linear()->set_domain(1, known_ub.value()); + changed = true; } } else if (status == RelationStatus::IS_FALSE) { return MarkConstraintAsFalse(ct, "linear2: infeasible relation"); } else if (ct->enforcement_literal().empty()) { known_linear2_.Add(expr2, lb, ub); + known_model_linear2_.Add(expr2, lb, ub); if (context_->ModelIsUnsat()) return false; } } @@ -2940,25 +2976,24 @@ bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { const Domain scaled_domain2 = context_->DomainOf(var2).MultiplicationBy(coeff2, &mult2_is_exact); if (mult1_is_exact && mult2_is_exact) { - const Domain infeasible_reachable_values = - scaled_domain1.AdditionWith(scaled_domain2) - .IntersectionWith(rhs.Complement()); - - if (!infeasible_reachable_values.IsEmpty() && - infeasible_reachable_values.IsFixed()) { - return PresolveLinear2NeCst(ct, infeasible_reachable_values.FixedValue()); + // We avoid IntersectionWith(rhs.Complement()) to not allocate memory + // for problem with thousands of linear2... + std::optional unique_not_reachable = + scaled_domain1.AdditionWith(scaled_domain2).UniqueValueNotIn(rhs); + if (unique_not_reachable != std::nullopt) { + return PresolveLinear2NeCst(ct, *unique_not_reachable) || changed; } } if (rhs.IsFixed()) { if (ct->enforcement_literal().empty()) { - return PresolveUnenforcedLinear2EqCst(ct, rhs.FixedValue()); + return PresolveUnenforcedLinear2EqCst(ct, rhs.FixedValue()) || changed; } else { - return PresolveEnforcedLinear2EqCst(ct, rhs.FixedValue()); + return PresolveEnforcedLinear2EqCst(ct, rhs.FixedValue()) || changed; } } - return PresolveLinear2WithBooleans(ct); + return PresolveLinear2WithBooleans(ct) || changed; } // If it is not an equality, we only presolve the constraint if one of @@ -3010,10 +3045,9 @@ bool CpModelPresolver::PresolveLinear2WithBooleans(ConstraintProto* ct) { context_->UpdateRuleStats("linear2: boolean with one feasible value"); // => true. - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->mutable_bool_and()->add_literals(lit); - context_->UpdateNewConstraintsVariableUsage(); // Rewrite to => var in rhs_if_true. ct->mutable_linear()->Clear(); @@ -3025,10 +3059,9 @@ bool CpModelPresolver::PresolveLinear2WithBooleans(ConstraintProto* ct) { context_->UpdateRuleStats("linear2: boolean with one feasible value"); // => false. - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); new_ct->mutable_bool_and()->add_literals(NegatedRef(lit)); - context_->UpdateNewConstraintsVariableUsage(); // Rewrite to => var in rhs_if_false. ct->mutable_linear()->Clear(); @@ -3047,7 +3080,7 @@ bool CpModelPresolver::PresolveLinear2WithBooleans(ConstraintProto* ct) { // lit => var \in rhs_if_true const Domain var_domain = context_->DomainOf(var); if (!var_domain.IsIncludedIn(rhs_if_true)) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->add_enforcement_literal(lit); new_ct->mutable_linear()->add_vars(var); new_ct->mutable_linear()->add_coeffs(1); @@ -3057,7 +3090,7 @@ bool CpModelPresolver::PresolveLinear2WithBooleans(ConstraintProto* ct) { // NegatedRef(lit) => var \in rhs_if_false if (!var_domain.IsIncludedIn(rhs_if_false)) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->add_enforcement_literal(NegatedRef(lit)); new_ct->mutable_linear()->add_vars(var); new_ct->mutable_linear()->add_coeffs(1); @@ -3065,7 +3098,6 @@ bool CpModelPresolver::PresolveLinear2WithBooleans(ConstraintProto* ct) { new_ct->mutable_linear()); } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -3138,7 +3170,6 @@ bool CpModelPresolver::PresolveLinear2NeCst(ConstraintProto* ct, int64_t rhs) { << ", domain_of_z = " << domain_of_z << ", num_clauses = " << num_clauses; - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("linear2: convert ax + by != cte to clauses"); return RemoveConstraint(ct); } else { @@ -3256,7 +3287,6 @@ bool CpModelPresolver::PresolveEnforcedLinear2EqCst(ConstraintProto* ct, context_->UpdateRuleStats( "linear2: implied ax + by = cte has only one solution"); - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -3325,7 +3355,6 @@ bool CpModelPresolver::PresolveEnforcedLinear2EqCst(ConstraintProto* ct, << ", coeff1 = " << coeff1 << ", coeff2 = " << coeff2 << " equal_size=" << reduced_domain.Size(); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats( "linear2: convert implied ax + by == cte to clauses"); return RemoveConstraint(ct); @@ -3355,15 +3384,15 @@ bool CpModelPresolver::PresolveEmptyLinearConstraint(ConstraintProto* ct) { } } -bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) { +bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct, + bool canonicalize) { if (ct->constraint_case() != ConstraintProto::kLinear) return false; if (context_->ModelIsUnsat()) return false; + if (ct->linear().vars().size() > 2) return false; bool changed = false; - if (ct->linear().vars().size() <= 2) { - if (!CanonicalizeLinear(ct, &changed)) return true; - if (ct->constraint_case() != ConstraintProto::kLinear) return true; - } + if (canonicalize && !CanonicalizeLinear(ct, &changed)) return true; + if (ct->constraint_case() != ConstraintProto::kLinear) return true; if (ct->linear().vars().empty()) { return PresolveEmptyLinearConstraint(ct); @@ -3394,36 +3423,36 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { lbs[i] = context_->MinOf(linear_constraint.vars(i)); ubs[i] = context_->MaxOf(linear_constraint.vars(i)); } - DiophantineSolution diophantine_solution = SolveDiophantine( + const DiophantineSolution diophantine_sol = SolveDiophantine( linear_constraint.coeffs(), linear_constraint.domain(0), lbs, ubs); - if (!diophantine_solution.has_solutions) { + if (!diophantine_sol.has_solutions) { return MarkConstraintAsFalse(ct, "diophantine: equality has no solutions"); } - if (diophantine_solution.no_reformulation_needed) return false; + if (diophantine_sol.no_reformulation_needed) return false; // Only first coefficients of kernel_basis elements and special_solution could // overflow int64_t due to the reduction applied in SolveDiophantineEquation, - for (const std::vector& b : diophantine_solution.kernel_basis) { + for (const std::vector& b : diophantine_sol.kernel_basis) { if (!IsNegatableInt64(b[0])) { context_->UpdateRuleStats( "diophantine: couldn't apply due to int64_t overflow"); return false; } } - if (!IsNegatableInt64(diophantine_solution.special_solution[0])) { + if (!IsNegatableInt64(diophantine_sol.special_solution[0])) { context_->UpdateRuleStats( "diophantine: couldn't apply due to int64_t overflow"); return false; } const int num_replaced_variables = - static_cast(diophantine_solution.special_solution.size()); + static_cast(diophantine_sol.special_solution.size()); const int num_new_variables = - static_cast(diophantine_solution.kernel_vars_lbs.size()); + static_cast(diophantine_sol.kernel_vars_lbs.size()); DCHECK_EQ(num_new_variables + 1, num_replaced_variables); for (int i = 0; i < num_new_variables; ++i) { - if (!IsNegatableInt64(diophantine_solution.kernel_vars_lbs[i]) || - !IsNegatableInt64(diophantine_solution.kernel_vars_ubs[i])) { + if (!IsNegatableInt64(diophantine_sol.kernel_vars_lbs[i]) || + !IsNegatableInt64(diophantine_sol.kernel_vars_ubs[i])) { context_->UpdateRuleStats( "diophantine: couldn't apply due to int64_t overflow"); return false; @@ -3434,18 +3463,11 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { // We should check that the model still satisfy conditions in // `PossibleIntegerOverflow` (sat/cp_model_checker.cc) - // Create new variables. + // Initialize indices for new variables. + // We will only create them if we don't abort due to overflow.. std::vector new_variables(num_new_variables); for (int i = 0; i < num_new_variables; ++i) { - new_variables[i] = context_->working_model->variables_size(); - IntegerVariableProto* var = context_->working_model->add_variables(); - var->add_domain( - static_cast(diophantine_solution.kernel_vars_lbs[i])); - var->add_domain( - static_cast(diophantine_solution.kernel_vars_ubs[i])); - if (!ct->name().empty()) { - var->set_name(absl::StrCat("u_diophantine_", ct->name(), "_", i)); - } + new_variables[i] = context_->NumVariables() + i; } // For i = 0, ..., num_replaced_variables - 1, creates @@ -3455,64 +3477,92 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { // y[k] is the newly created variable if 0 <= k < num_new_variables // y[k] = x[index_permutation[k + 1]] otherwise. std::vector> lin_vars_lbs(num_replaced_variables); + const int old_num_constraints = context_->NumConstraints(); for (int i = 0; i < num_replaced_variables; ++i) { - ConstraintProto* identity = context_->working_model->add_constraints(); - LinearConstraintProto* lin = identity->mutable_linear(); + LinearOverflowChecker checker; + bool safe = true; + + ConstraintProto* new_ct = context_->AddEnforcedConstraint(ct); + LinearConstraintProto* lin = new_ct->mutable_linear(); if (!ct->name().empty()) { - identity->set_name(absl::StrCat("c_diophantine_", ct->name(), "_", i)); + new_ct->set_name(absl::StrCat("c_diophantine_", ct->name(), "_", i)); } - *identity->mutable_enforcement_literal() = ct->enforcement_literal(); + const int var = - linear_constraint.vars(diophantine_solution.index_permutation[i]); + linear_constraint.vars(diophantine_sol.index_permutation[i]); lin->add_vars(var); lin_vars_lbs[i].push_back(context_->MinOf(var)); lin->add_coeffs(1); - lin->add_domain( - static_cast(diophantine_solution.special_solution[i])); - lin->add_domain( - static_cast(diophantine_solution.special_solution[i])); + safe &= checker.AddTerm(1, context_->MinOf(var), context_->MaxOf(var)); + + lin->add_domain(static_cast(diophantine_sol.special_solution[i])); + lin->add_domain(static_cast(diophantine_sol.special_solution[i])); for (int j = std::max(1, i); j < num_replaced_variables; ++j) { + const int64_t lb = + static_cast(diophantine_sol.kernel_vars_lbs[j - 1]); + const int64_t ub = + static_cast(diophantine_sol.kernel_vars_ubs[j - 1]); + const int64_t coeff = + -static_cast(diophantine_sol.kernel_basis[j - 1][i]); + lin_vars_lbs[i].push_back(lb); + lin->add_vars(new_variables[j - 1]); - lin_vars_lbs[i].push_back( - static_cast(diophantine_solution.kernel_vars_lbs[j - 1])); - lin->add_coeffs( - -static_cast(diophantine_solution.kernel_basis[j - 1][i])); + lin->add_coeffs(coeff); + if (!checker.AddTerm(coeff, lb, ub)) { + safe = false; + break; + } } for (int j = num_replaced_variables; j < linear_constraint.vars_size(); ++j) { const int var = - linear_constraint.vars(diophantine_solution.index_permutation[j]); + linear_constraint.vars(diophantine_sol.index_permutation[j]); + const int64_t lb = context_->MinOf(var); + const int64_t ub = context_->MaxOf(var); + const int64_t coeff = + -static_cast(diophantine_sol.kernel_basis[j - 1][i]); + lin_vars_lbs[i].push_back(lb); lin->add_vars(var); - lin_vars_lbs[i].push_back(context_->MinOf(var)); - lin->add_coeffs( - -static_cast(diophantine_solution.kernel_basis[j - 1][i])); + lin->add_coeffs(coeff); + if (!checker.AddTerm(coeff, lb, ub)) { + safe = false; + break; + } } - - // TODO(user): The domain in the proto are not necessarily up to date so - // this might be stricter than necessary. Fix? It shouldn't matter too much - // though. - if (PossibleIntegerOverflow(*(context_->working_model), lin->vars(), - lin->coeffs())) { + if (!safe) { context_->UpdateRuleStats( "diophantine: couldn't apply due to overflowing activity of new " "constraints"); // Cancel working_model changes. - context_->working_model->mutable_constraints()->DeleteSubrange( - context_->working_model->constraints_size() - i - 1, i + 1); - context_->working_model->mutable_variables()->DeleteSubrange( - context_->working_model->variables_size() - num_new_variables, - num_new_variables); + for (int j = 0; j <= i; ++j) { + context_->RemoveLastConstraint(); + } + CHECK_EQ(old_num_constraints, context_->NumConstraints()); return false; } } + + // We are good to go. Do create the new variables now. + for (int i = 0; i < num_new_variables; ++i) { + const int64_t lb = static_cast(diophantine_sol.kernel_vars_lbs[i]); + const int64_t ub = static_cast(diophantine_sol.kernel_vars_ubs[i]); + const int var = context_->NewIntVar(Domain(lb, ub)); + context_->UpdateRuleStats("new_int_var: diophantine solution"); + CHECK_EQ(var, new_variables[i]); + if (!ct->name().empty()) { + context_->SetVarName(var, + absl::StrCat("u_diophantine_", ct->name(), "_", i)); + } + } context_->InitializeNewDomains(); + // Scan the new constraints added above in reverse order so that the hint of // `new_variables[k]` can be computed from the hint of the existing variables // and from the hints of `new_variables[k']`, with k' > k. - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int i = 0; i < num_replaced_variables; ++i) { const LinearConstraintProto& linear = - context_->working_model->constraints(num_constraints - 1 - i).linear(); + context_->Constraint(num_constraints - 1 - i).linear(); DCHECK(linear.domain_size() == 2 && linear.domain(0) == linear.domain(1)); solution_crush_.SetVarToLinearConstraintSolution( ct->enforcement_literal(), std::nullopt, linear.vars(), lin_vars_lbs[i], @@ -3526,9 +3576,8 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { if (i > 0) absl::StrAppend(&log_eq, " + "); absl::StrAppend( &log_eq, - linear_constraint.coeffs(diophantine_solution.index_permutation[i]), - " x", - linear_constraint.vars(diophantine_solution.index_permutation[i])); + linear_constraint.coeffs(diophantine_sol.index_permutation[i]), " x", + linear_constraint.vars(diophantine_sol.index_permutation[i])); } if (terms_to_show < linear_constraint.vars_size()) { absl::StrAppend(&log_eq, "+ ... (", linear_constraint.vars_size(), @@ -3538,7 +3587,6 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) { } context_->UpdateRuleStats("diophantine: reformulated equality"); - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -3905,12 +3953,13 @@ void CpModelPresolver::ProcessAtMostOneAndLinear() { PresolveTimer timer(__FUNCTION__, logger_, time_limit_); ActivityBoundHelper amo_in_linear; - amo_in_linear.AddAllAtMostOnes(*context_->working_model); + amo_in_linear.AddAllAtMostOnes(context_->WorkingModel()); int num_changes = 0; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); + temp_ct_.Clear(); for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() != ConstraintProto::kLinear) continue; // We loop if the constraint changed. @@ -3946,10 +3995,11 @@ void CpModelPresolver::ProcessOneLinearWithAmo(int ct_index, // is NOT canonicalized. This is because we might detect equivalence here and // we will continue with ProcessOneLinearWithAmo() in the parent loop. tmp_terms_.clear(); - temp_ct_.Clear(); + DCHECK(temp_ct_.enforcement_literal().empty()); + temp_ct_.mutable_linear()->Clear(); Domain non_boolean_domain(0); const int initial_size = ct->linear().vars().size(); - int64_t min_magnitude = std::numeric_limits::max(); + int64_t min_magnitude = kint64max; int64_t max_magnitude = 0; for (int i = 0; i < initial_size; ++i) { // TODO(user): Just do not use negative reference in linear! @@ -4150,13 +4200,12 @@ void CpModelPresolver::ProcessOneLinearWithAmo(int ct_index, // To remove them, we just add them to temp_set_ and FixLiteralFromSet() // will take care of it. context_->UpdateRuleStats("linear + amo: added implications"); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); for (const int lit : must_be_true) { new_ct->mutable_bool_and()->add_literals(lit); temp_set_.insert(lit); } - context_->UpdateNewConstraintsVariableUsage(); } const int num_fixed = FixLiteralFromSet(temp_set_, ct->mutable_linear()); @@ -4171,7 +4220,7 @@ void CpModelPresolver::ProcessOneLinearWithAmo(int ct_index, if (ct->linear().vars().empty()) { context_->UpdateRuleStats("linear + amo: empty after processing"); - PresolveSmallLinear(ct); + PresolveEmptyLinearConstraint(ct); context_->UpdateConstraintVariableUsage(ct_index); return; } @@ -4416,13 +4465,12 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, break; } if (c == ct_index) continue; - if (context_->working_model->constraints(c).constraint_case() != + if (context_->Constraint(c).constraint_case() != ConstraintProto::kLinear) { abort = true; break; } - for (const int ref : - context_->working_model->constraints(c).enforcement_literal()) { + for (const int ref : context_->Constraint(c).enforcement_literal()) { if (PositiveRef(ref) == var) { abort = true; break; @@ -4445,8 +4493,7 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, // TODO(user): We still substitute even if this happens in the objective // though. Is that good? if (others.size() != 1) continue; - const ConstraintProto& other_ct = - context_->working_model->constraints(others.front()); + const ConstraintProto& other_ct = context_->Constraint(others.front()); if (!other_ct.enforcement_literal().empty()) continue; // Compute the implied domain using the other constraint. @@ -4499,16 +4546,15 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, // TODO(user): The copy is needed to have a simpler overflow-checking // code were we check once the substitution is done. If needed we could // optimize that, but with more code. - copy_if_we_abort = context_->working_model->constraints(c); + copy_if_we_abort = context_->Constraint(c); // In some corner cases, this might violate our overflow precondition or // even create an overflow. The danger is limited since the range of the // linear expression used in the definition do not exceed the domain of // the variable we substitute. But this is not the case for the doubleton // case above. - if (!SubstituteVariable( - var, var_coeff, *ct, - context_->working_model->mutable_constraints(c))) { + if (!SubstituteVariable(var, var_coeff, *ct, + context_->MutableConstraint(c))) { // The function above can fail because of overflow, but also if the // constraint was not canonicalized yet and the variable is actually not // there (we have var - var for instance). @@ -4516,8 +4562,7 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, // TODO(user): we canonicalize it right away, but I am not sure it is // really needed. bool changed = false; - if (!CanonicalizeLinear(context_->working_model->mutable_constraints(c), - &changed)) { + if (!CanonicalizeLinear(context_->MutableConstraint(c), &changed)) { return true; } if (changed) { @@ -4527,12 +4572,11 @@ bool CpModelPresolver::PropagateDomainsInLinear(int ct_index, break; } - if (PossibleIntegerOverflow( - *context_->working_model, - context_->working_model->constraints(c).linear().vars(), - context_->working_model->constraints(c).linear().coeffs())) { + if (PossibleIntegerOverflow(context_->WorkingModel(), + context_->Constraint(c).linear().vars(), + context_->Constraint(c).linear().coeffs())) { // Revert the change in this case. - *context_->working_model->mutable_constraints(c) = copy_if_we_abort; + *context_->MutableConstraint(c) = copy_if_we_abort; abort = true; break; } @@ -4747,7 +4791,7 @@ void CpModelPresolver::ExtractEnforcementLiteralFromLinearConstraint( int64_t min_sum = 0; int64_t max_sum = 0; int64_t max_coeff_magnitude = 0; - int64_t min_coeff_magnitude = std::numeric_limits::max(); + int64_t min_coeff_magnitude = kint64max; for (int i = 0; i < num_vars; ++i) { const int ref = arg.vars(i); const int64_t coeff = arg.coeffs(i); @@ -4824,7 +4868,7 @@ void CpModelPresolver::ExtractEnforcementLiteralFromLinearConstraint( if (max_coeff_magnitude < std::max(ub_threshold, lb_threshold)) return; context_->UpdateRuleStats("linear: split boxed constraint"); - ConstraintProto* new_ct1 = context_->working_model->add_constraints(); + ConstraintProto* new_ct1 = context_->AddConstraint(); *new_ct1 = *ct; if (!ct->name().empty()) { new_ct1->set_name(absl::StrCat(ct->name(), " (part 1)")); @@ -4832,7 +4876,7 @@ void CpModelPresolver::ExtractEnforcementLiteralFromLinearConstraint( FillDomainInProto(Domain(min_sum, rhs_domain.Max()), new_ct1->mutable_linear()); - ConstraintProto* new_ct2 = context_->working_model->add_constraints(); + ConstraintProto* new_ct2 = context_->AddConstraint(); *new_ct2 = *ct; if (!ct->name().empty()) { new_ct2->set_name(absl::StrCat(ct->name(), " (part 2)")); @@ -4840,7 +4884,6 @@ void CpModelPresolver::ExtractEnforcementLiteralFromLinearConstraint( FillDomainInProto(rhs_domain.UnionWith(Domain(rhs_domain.Max(), max_sum)), new_ct2->mutable_linear()); - context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); context_->UpdateConstraintVariableUsage(ct_index); return; @@ -4955,7 +4998,6 @@ void CpModelPresolver::ExtractEnforcementLiteralFromLinearConstraint( FillDomainInProto(rhs_domain.AdditionWith(Domain(rhs_offset)), mutable_arg); if (some_integer_encoding_were_extracted || new_size == 1) { context_->UpdateConstraintVariableUsage(ct_index); - context_->UpdateNewConstraintsVariableUsage(); } } @@ -5002,12 +5044,11 @@ void CpModelPresolver::ExtractAtMostOneFromLinear(ConstraintProto* ct) { } else { context_->UpdateRuleStats("linear: extracted at most one (min)"); } - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->set_name(ct->name()); for (const int ref : at_most_one) { new_ct->mutable_at_most_one()->add_literals(ref); } - context_->UpdateNewConstraintsVariableUsage(); } } } @@ -5025,7 +5066,7 @@ bool CpModelPresolver::PresolveLinearOnBooleans(ConstraintProto* ct) { const LinearConstraintProto& arg = ct->linear(); const int num_vars = arg.vars_size(); - int64_t min_coeff = std::numeric_limits::max(); + int64_t min_coeff = kint64max; int64_t max_coeff = 0; int64_t min_sum = 0; int64_t max_sum = 0; @@ -5180,13 +5221,12 @@ bool CpModelPresolver::PresolveLinearOnBooleans(ConstraintProto* ct) { min_sum + max_coeff <= rhs_domain.Max()) { // TODO(user): Support enforced exactly one. context_->UpdateRuleStats("linear: positive equal one"); - ConstraintProto* exactly_one = context_->working_model->add_constraints(); + ConstraintProto* exactly_one = context_->AddConstraint(); exactly_one->set_name(ct->name()); for (int i = 0; i < num_vars; ++i) { exactly_one->mutable_exactly_one()->add_literals( arg.coeffs(i) > 0 ? arg.vars(i) : NegatedRef(arg.vars(i))); } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } else if (!HasEnforcementLiteral(*ct) && rhs_domain.NumIntervals() == 1 && max_sum > rhs_domain.Max() && @@ -5195,13 +5235,12 @@ bool CpModelPresolver::PresolveLinearOnBooleans(ConstraintProto* ct) { max_sum - max_coeff >= rhs_domain.Min()) { // TODO(user): Support enforced exactly one. context_->UpdateRuleStats("linear: negative equal one"); - ConstraintProto* exactly_one = context_->working_model->add_constraints(); + ConstraintProto* exactly_one = context_->AddConstraint(); exactly_one->set_name(ct->name()); for (int i = 0; i < num_vars; ++i) { exactly_one->mutable_exactly_one()->add_literals( arg.coeffs(i) > 0 ? NegatedRef(arg.vars(i)) : arg.vars(i)); } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -5291,7 +5330,6 @@ bool CpModelPresolver::PresolveLinearOnBooleans(ConstraintProto* ct) { } } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -5308,9 +5346,8 @@ bool CpModelPresolver::PresolveInterval(int c, ConstraintProto* ct) { if (ct->enforcement_literal().empty()) { bool domain_changed = false; // Size can't be negative. - if (!context_->IntersectDomainWith( - interval->size(), Domain(0, std::numeric_limits::max()), - &domain_changed)) { + if (!context_->IntersectDomainWith(interval->size(), Domain(0, kint64max), + &domain_changed)) { return false; } if (domain_changed) { @@ -5321,8 +5358,7 @@ bool CpModelPresolver::PresolveInterval(int c, ConstraintProto* ct) { // Note that the linear relation is stored elsewhere, so it is safe to just // remove such special interval constraint. - if (context_->ConstraintVariableGraphIsUpToDate() && - context_->IntervalUsage(c) == 0) { + if (context_->IntervalUsage(c) == 0) { context_->UpdateRuleStats("intervals: removed unused interval"); return RemoveConstraint(ct); } @@ -5336,20 +5372,21 @@ bool CpModelPresolver::PresolveInterval(int c, ConstraintProto* ct) { // TODO(user): avoid code duplication between expand and presolve. bool CpModelPresolver::PresolveInverse(ConstraintProto* ct) { + DCHECK(ct->inverse().f_direct().empty()); // TODO(user): add support for this case. if (HasEnforcementLiteral(*ct)) return false; - const int size = ct->inverse().f_direct().size(); + const int size = ct->inverse().f_expr_direct().size(); bool changed = false; // Make sure the domains are included in [0, size - 1). - for (const int ref : ct->inverse().f_direct()) { - if (!context_->IntersectDomainWith(ref, Domain(0, size - 1), &changed)) { + for (const LinearExpressionProto& expr : ct->inverse().f_expr_direct()) { + if (!context_->IntersectDomainWith(expr, Domain(0, size - 1), &changed)) { VLOG(1) << "Empty domain for a variable in PresolveInverse()"; return false; } } - for (const int ref : ct->inverse().f_inverse()) { - if (!context_->IntersectDomainWith(ref, Domain(0, size - 1), &changed)) { + for (const LinearExpressionProto& expr : ct->inverse().f_expr_inverse()) { + if (!context_->IntersectDomainWith(expr, Domain(0, size - 1), &changed)) { VLOG(1) << "Empty domain for a variable in PresolveInverse()"; return false; } @@ -5361,71 +5398,102 @@ bool CpModelPresolver::PresolveInverse(ConstraintProto* ct) { // appear together. { absl::flat_hash_set direct_vars; - for (const int ref : ct->inverse().f_direct()) { - const auto [it, inserted] = direct_vars.insert(PositiveRef(ref)); + for (const LinearExpressionProto& expr : ct->inverse().f_expr_direct()) { + DCHECK_LE(expr.vars_size(), 1); + if (expr.vars().empty()) continue; + if (abs(expr.coeffs(0)) != 1 || expr.offset() != 0) continue; + const int var = expr.vars(0); + DCHECK_GE(var, 0); + const auto [it, inserted] = direct_vars.insert(var); if (!inserted) { return context_->NotifyThatModelIsUnsat("inverse: duplicated variable"); } } absl::flat_hash_set inverse_vars; - for (const int ref : ct->inverse().f_inverse()) { - const auto [it, inserted] = inverse_vars.insert(PositiveRef(ref)); + for (const LinearExpressionProto& expr : ct->inverse().f_expr_inverse()) { + DCHECK_LE(expr.vars_size(), 1); + if (expr.vars().empty()) continue; + if (abs(expr.coeffs(0)) != 1 || expr.offset() != 0) continue; + const int var = expr.vars(0); + DCHECK_GE(var, 0); + const auto [it, inserted] = inverse_vars.insert(var); if (!inserted) { return context_->NotifyThatModelIsUnsat("inverse: duplicated variable"); } } } - // Propagate from one vector to its counterpart. - // Note this reaches the fixpoint as there is a one to one mapping between - // (variable-value) pairs in each vector. - const auto filter_inverse_domain = - [this, size, &changed](const auto& direct, const auto& inverse) { - // Build the set of values in the inverse vector. - std::vector> inverse_values(size); - for (int i = 0; i < size; ++i) { - const Domain domain = context_->DomainOf(inverse[i]); - for (const int64_t j : domain.Values()) { - inverse_values[i].insert(j); - } + // Propagate from one vector to its counterpart until fix point. + const auto filter_inverse_domain = [this, size, &changed](const auto& direct, + const auto& inverse, + bool& filtered) { + // Build the set of values in the inverse vector. + std::vector> inverse_values(size); + for (int i = 0; i < size; ++i) { + if (inverse[i].vars().empty()) { + inverse_values[i].insert(inverse[i].offset()); + } else { + for (const int64_t var_value : + context_->DomainOf(inverse[i].vars(0)).Values()) { + const int64_t j = + var_value * inverse[i].coeffs(0) + inverse[i].offset(); + inverse_values[i].insert(j); + } + } + } + + // Propagate from the inverse vector to the direct vector. Reduce the + // domains of each variable in the direct vector by checking that the + // inverse value exists. + std::vector possible_values; + for (int i = 0; i < size; ++i) { + possible_values.clear(); + bool removed_value = false; + if (direct[i].vars().empty()) { + const int64_t j = direct[i].offset(); + if (inverse_values[j].contains(i)) { + possible_values.push_back(j); + } else { + removed_value = true; } - - // Propagate from the inverse vector to the direct vector. Reduce the - // domains of each variable in the direct vector by checking that the - // inverse value exists. - std::vector possible_values; - for (int i = 0; i < size; ++i) { - possible_values.clear(); - const Domain domain = context_->DomainOf(direct[i]); - bool removed_value = false; - for (const int64_t j : domain.Values()) { - if (inverse_values[j].contains(i)) { - possible_values.push_back(j); - } else { - removed_value = true; - } - } - if (removed_value) { - changed = true; - if (!context_->IntersectDomainWith( - direct[i], Domain::FromValues(possible_values))) { - VLOG(1) << "Empty domain for a variable in PresolveInverse()"; - return false; - } + } else { + for (const int64_t var_value : + context_->DomainOf(direct[i].vars(0)).Values()) { + const int64_t j = + var_value * direct[i].coeffs(0) + direct[i].offset(); + if (inverse_values[j].contains(i)) { + possible_values.push_back(j); + } else { + removed_value = true; } } - return true; - }; + } + if (removed_value) { + changed = true; + filtered = true; + if (!context_->IntersectDomainWith( + direct[i], Domain::FromValues(possible_values))) { + VLOG(1) << "Empty domain for a variable in PresolveInverse()"; + return false; + } + } + } + return true; + }; - if (!filter_inverse_domain(ct->inverse().f_direct(), - ct->inverse().f_inverse())) { - return false; - } + for (int i = 0; i < 100; ++i) { // Just to avoid potentially bad cases. + bool filtered = false; + if (!filter_inverse_domain(ct->inverse().f_expr_direct(), + ct->inverse().f_expr_inverse(), filtered)) { + return false; + } - if (!filter_inverse_domain(ct->inverse().f_inverse(), - ct->inverse().f_direct())) { - return false; + if (!filter_inverse_domain(ct->inverse().f_expr_inverse(), + ct->inverse().f_expr_direct(), filtered)) { + return false; + } + if (!filtered) break; } if (changed) { @@ -5475,7 +5543,7 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { // Special case if the index is fixed. if (context_->IsFixed(index)) { const int64_t index_value = context_->FixedValue(index); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); new_ct->mutable_linear()->add_domain(0); new_ct->mutable_linear()->add_domain(0); AddLinearExpressionToLinearConstraint(target, 1, new_ct->mutable_linear()); @@ -5487,7 +5555,6 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { return context_->NotifyThatModelIsUnsat( "element: impossible fixed index"); } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("element: fixed index"); return RemoveConstraint(ct); } @@ -5567,7 +5634,7 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { context_->UpdateRuleStats("element: reduced index domain "); // If the index is fixed, this is a equality constraint. if (context_->IsFixed(index)) { - ConstraintProto* const eq = context_->working_model->add_constraints(); + ConstraintProto* const eq = context_->AddConstraint(); eq->mutable_linear()->add_domain(0); eq->mutable_linear()->add_domain(0); AddLinearExpressionToLinearConstraint(target, 1, eq->mutable_linear()); @@ -5575,7 +5642,6 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { ct->element().exprs(context_->FixedValue(index)), -1, eq->mutable_linear()); context_->CanonicalizeLinearConstraint(eq); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("element: fixed index"); return RemoveConstraint(ct); } @@ -5675,7 +5741,7 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { const int64_t offset = CapSub(CapProd(first_target_var_value, d_index), CapProd(first_index_var_value, d_target)); if (!AtMinOrMaxInt64(offset)) { - ConstraintProto* const lin = context_->working_model->add_constraints(); + ConstraintProto* const lin = context_->AddConstraint(); lin->mutable_linear()->add_vars(target.vars(0)); lin->mutable_linear()->add_coeffs(d_index); lin->mutable_linear()->add_vars(index_var); @@ -5683,7 +5749,6 @@ bool CpModelPresolver::PresolveElement(int c, ConstraintProto* ct) { lin->mutable_linear()->add_domain(offset); lin->mutable_linear()->add_domain(offset); context_->CanonicalizeLinearConstraint(lin); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("element: rewrite as affine constraint"); return RemoveConstraint(ct); } @@ -6079,28 +6144,27 @@ namespace { // Add the constraint (lhs => rhs) to the given proto. The hash map lhs -> // bool_and constraint index is used to merge implications with the same lhs. -void AddImplication(int lhs, int rhs, CpModelProto* proto, - absl::flat_hash_map* ref_to_bool_and) { +void AddImplicationWithMerging(int lhs, int rhs, PresolveContext* context, + absl::flat_hash_map* ref_to_bool_and) { if (ref_to_bool_and->contains(lhs)) { const int ct_index = (*ref_to_bool_and)[lhs]; - proto->mutable_constraints(ct_index)->mutable_bool_and()->add_literals(rhs); + context->MutableConstraint(ct_index)->mutable_bool_and()->add_literals(rhs); } else if (ref_to_bool_and->contains(NegatedRef(rhs))) { const int ct_index = (*ref_to_bool_and)[NegatedRef(rhs)]; - proto->mutable_constraints(ct_index)->mutable_bool_and()->add_literals( + context->MutableConstraint(ct_index)->mutable_bool_and()->add_literals( NegatedRef(lhs)); } else { - (*ref_to_bool_and)[lhs] = proto->constraints_size(); - ConstraintProto* ct = proto->add_constraints(); + (*ref_to_bool_and)[lhs] = context->NumConstraints(); + ConstraintProto* ct = context->AddConstraint(); ct->add_enforcement_literal(lhs); ct->mutable_bool_and()->add_literals(rhs); } } template -void ExtractClauses(bool merge_into_bool_and, - absl::Span index_mapping, - const ClauseContainer& container, CpModelProto* proto, - std::string_view debug_name = "") { +void ExtractClausesToContext(absl::Span index_mapping, + const ClauseContainer& container, + PresolveContext* context) { // We regroup the "implication" into bool_and to have a more concise proto and // also for nicer information about the number of binary clauses. // @@ -6115,16 +6179,47 @@ void ExtractClauses(bool merge_into_bool_and, // bool_and. // // TODO(user): Be smarter in how we regroup clause of size 2? - if (merge_into_bool_and && clause.size() == 2) { + if (clause.size() == 2) { const int var_a = index_mapping[clause[0].Variable().value()]; const int var_b = index_mapping[clause[1].Variable().value()]; const int ref_a = clause[0].IsPositive() ? var_a : NegatedRef(var_a); const int ref_b = clause[1].IsPositive() ? var_b : NegatedRef(var_b); - AddImplication(NegatedRef(ref_a), ref_b, proto, &ref_to_bool_and); + AddImplicationWithMerging(NegatedRef(ref_a), ref_b, context, + &ref_to_bool_and); continue; } // bool_or. + ConstraintProto* ct = context->AddConstraint(); + ct->mutable_bool_or()->mutable_literals()->Reserve(clause.size()); + for (const Literal l : clause) { + const int var = index_mapping[l.Variable().value()]; + if (l.IsPositive()) { + ct->mutable_bool_or()->add_literals(var); + } else { + ct->mutable_bool_or()->add_literals(NegatedRef(var)); + } + } + } + + DCHECK(context->ConstraintVariableUsageIsConsistent()); +} + +void ExtractClausesToMappingModelProto(absl::Span index_mapping, + const SatPostsolver& container, + CpModelProto* proto) { + const std::string debug_name = + absl::GetFlag(FLAGS_cp_model_debug_postsolve) ? "sat_postsolver" : ""; + + // We regroup the "implication" into bool_and to have a more concise proto and + // also for nicer information about the number of binary clauses. + // + // Important: however, we do not do that for the model used during postsolving + // since the order of the constraints might be important there depending on + // how we perform the postsolve. + for (int i = 0; i < container.NumClauses(); ++i) { + const auto& clause = container.Clause(i); + if (clause.empty()) continue; ConstraintProto* ct = proto->add_constraints(); if (!debug_name.empty()) { ct->set_name(debug_name); @@ -6200,7 +6295,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { if (!visited_intervals.insert(interval_index).second) continue; ConstraintProto* interval_ct = - context_->working_model->mutable_constraints(interval_index); + context_->MutableConstraint(interval_index); // Case 1: size > 0. Interval must be unperformed. if (context_->SizeMin(interval_index) > 0) { @@ -6232,8 +6327,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { // intervals still cannot overlap with other intervals. } else { // Case 3: interval is optional and size can be > 0. const int performed_literal = interval_ct->enforcement_literal(0); - ConstraintProto* size_eq_zero = - context_->working_model->add_constraints(); + ConstraintProto* size_eq_zero = context_->AddConstraint(); size_eq_zero->add_enforcement_literal(performed_literal); size_eq_zero->mutable_linear()->add_domain(0); size_eq_zero->mutable_linear()->add_domain(0); @@ -6243,7 +6337,6 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { context_->UpdateRuleStats( "no_overlap: make duplicate intervals as unperformed or zero " "sized"); - context_->UpdateNewConstraintsVariableUsage(); } } } @@ -6260,6 +6353,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { // Split constraints in disjoint sets. if (proto->intervals_size() > 1) { std::vector indexed_intervals; + indexed_intervals.reserve(proto->intervals_size()); for (int i = 0; i < proto->intervals().size(); ++i) { const int index = proto->intervals(i); indexed_intervals.push_back({index, @@ -6274,22 +6368,20 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { if (intervals.size() <= 1) continue; NoOverlapConstraintProto* new_no_overlap = - context_->working_model->add_constraints()->mutable_no_overlap(); + context_->AddConstraint()->mutable_no_overlap(); // Fill in the intervals. Unfortunately, the Assign() method does not // compile in or-tools. for (const int i : intervals) { new_no_overlap->add_intervals(i); } } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("no_overlap: split into disjoint components"); return RemoveConstraint(ct); } } std::vector constant_intervals; - int64_t size_min_of_non_constant_intervals = - std::numeric_limits::max(); + int64_t size_min_of_non_constant_intervals = kint64max; for (int i = 0; i < proto->intervals_size(); ++i) { const int interval_index = proto->intervals(i); if (context_->IntervalIsConstant(interval_index)) { @@ -6347,9 +6439,9 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { } const int64_t new_start = context_->StartMin(constant_intervals[start]); const int64_t new_end = context_->EndMax(constant_intervals[i]); - proto->add_intervals(context_->working_model->constraints_size()); + proto->add_intervals(context_->NumConstraints()); IntervalConstraintProto* new_interval = - context_->working_model->add_constraints()->mutable_interval(); + context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(new_start); new_interval->mutable_size()->set_offset(new_end - new_start); new_interval->mutable_end()->set_offset(new_end); @@ -6374,7 +6466,6 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { intervals_to_remove.clear(); constant_intervals.clear(); changed = true; - context_->UpdateNewConstraintsVariableUsage(); } } @@ -6383,8 +6474,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { int num_size_zero_or_one = 0; bool has_optional_size_one = false; for (const int index : proto->intervals()) { - const ConstraintProto& interval_ct = - context_->working_model->constraints(index); + const ConstraintProto& interval_ct = context_->Constraint(index); const LinearExpressionProto& size = interval_ct.interval().size(); if (size.vars().empty() && size.offset() >= 0 && size.offset() <= 1) { ++num_size_zero_or_one; @@ -6402,7 +6492,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { int new_size = 0; for (const int index : proto->intervals()) { const IntervalConstraintProto& interval = - context_->working_model->constraints(index).interval(); + context_->Constraint(index).interval(); if (interval.size().offset() == 0) continue; proto->set_intervals(new_size++, index); } @@ -6418,14 +6508,13 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { context_->AddEnforcedConstraint(ct)->mutable_all_diff(); for (const int index : proto->intervals()) { const IntervalConstraintProto& interval = - context_->working_model->constraints(index).interval(); + context_->Constraint(index).interval(); if (interval.size().offset() == 0) continue; *all_diff->add_exprs() = interval.start(); } if (all_diff->exprs_size() < initial_num_intervals) { context_->UpdateRuleStats("no_overlap: removed size 0 from all diff"); } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("no_overlap: converted to all diff"); return RemoveConstraint(ct); } @@ -6446,8 +6535,7 @@ bool CpModelPresolver::PresolveNoOverlap(ConstraintProto* ct) { // intervals. if (move_constraint_last) { changed = true; - *context_->working_model->add_constraints() = *ct; - context_->UpdateNewConstraintsVariableUsage(); + *context_->AddConstraint() = *ct; return RemoveConstraint(ct); } @@ -6494,16 +6582,10 @@ bool CpModelPresolver::PresolveNoOverlap2DFramed( } const int x_interval_index = proto.x_intervals(box.box_index); const int y_interval_index = proto.y_intervals(box.box_index); - if (!context_->working_model->constraints(x_interval_index) - .enforcement_literal() - .empty() && - !context_->working_model->constraints(y_interval_index) - .enforcement_literal() - .empty()) { - if (context_->working_model->constraints(x_interval_index) - .enforcement_literal(0) != - context_->working_model->constraints(y_interval_index) - .enforcement_literal(0)) { + if (!context_->Constraint(x_interval_index).enforcement_literal().empty() && + !context_->Constraint(y_interval_index).enforcement_literal().empty()) { + if (context_->Constraint(x_interval_index).enforcement_literal(0) != + context_->Constraint(y_interval_index).enforcement_literal(0)) { // Two different enforcement literals. return false; } @@ -6520,9 +6602,9 @@ bool CpModelPresolver::PresolveNoOverlap2DFramed( const int x_interval_index = proto.x_intervals(box_index); const int y_interval_index = proto.y_intervals(box_index); const ConstraintProto& x_interval_ct = - context_->working_model->constraints(x_interval_index); + context_->Constraint(x_interval_index); const ConstraintProto& y_interval_ct = - context_->working_model->constraints(y_interval_index); + context_->Constraint(y_interval_index); if (x_interval_ct.enforcement_literal().empty() && y_interval_ct.enforcement_literal().empty()) { // Mandatory box, update the domains. @@ -6582,14 +6664,11 @@ bool CpModelPresolver::PresolveNoOverlap2DFramed( } } } else if (enforcement_literals_for_amo.size() > 1) { - context_->working_model->add_constraints() - ->mutable_at_most_one() - ->mutable_literals() - ->Add(enforcement_literals_for_amo.begin(), - enforcement_literals_for_amo.end()); + context_->AddConstraint()->mutable_at_most_one()->mutable_literals()->Add( + enforcement_literals_for_amo.begin(), + enforcement_literals_for_amo.end()); } context_->UpdateRuleStats("no_overlap_2d: at most one rectangle in region"); - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -6611,9 +6690,9 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( for (int idx = 0; idx < non_fixed_boxes.size(); ++idx) { const int b = non_fixed_boxes[idx].box_index; const ConstraintProto& x_interval_ct = - context_->working_model->constraints(proto.x_intervals(b)); + context_->Constraint(proto.x_intervals(b)); const ConstraintProto& y_interval_ct = - context_->working_model->constraints(proto.y_intervals(b)); + context_->Constraint(proto.y_intervals(b)); if (x_interval_ct.enforcement_literal().empty() && y_interval_ct.enforcement_literal().empty()) { // Mandatory box, cannot be in the "absent" bin -1. @@ -6635,8 +6714,7 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( context_->NewBoolVarWithConjunction( {enforcement_literal, potentially_other_enforcement_literal}); - BoolArgumentProto* bool_or = - context_->working_model->add_constraints()->mutable_bool_or(); + BoolArgumentProto* bool_or = context_->AddConstraint()->mutable_bool_or(); bool_or->add_literals(NegatedRef(interval_is_absent_literal)); for (const int lit : {enforcement_literal, potentially_other_enforcement_literal}) { @@ -6653,18 +6731,15 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( disjoint_packing_presolve_result.bins[bin_index]; NoOverlap2DConstraintProto new_no_overlap_2d; for (const Rectangle& ret : bin.fixed_boxes) { - new_no_overlap_2d.add_x_intervals( - context_->working_model->constraints_size()); - new_no_overlap_2d.add_y_intervals( - context_->working_model->constraints_size() + 1); + new_no_overlap_2d.add_x_intervals(context_->NumConstraints()); + new_no_overlap_2d.add_y_intervals(context_->NumConstraints() + 1); IntervalConstraintProto* new_interval = - context_->working_model->add_constraints()->mutable_interval(); + context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(ret.x_min.value()); new_interval->mutable_size()->set_offset(ret.SizeX().value()); new_interval->mutable_end()->set_offset(ret.x_max.value()); - new_interval = - context_->working_model->add_constraints()->mutable_interval(); + new_interval = context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(ret.y_min.value()); new_interval->mutable_size()->set_offset(ret.SizeY().value()); new_interval->mutable_end()->set_offset(ret.y_max.value()); @@ -6677,25 +6752,21 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( .area_index = bin_index, .literal = presence_in_box_lit}); const ConstraintProto& x_interval_ct = - context_->working_model->constraints(proto.x_intervals(b)); + context_->Constraint(proto.x_intervals(b)); const ConstraintProto& y_interval_ct = - context_->working_model->constraints(proto.y_intervals(b)); - ConstraintProto* new_interval_x = - context_->working_model->add_constraints(); + context_->Constraint(proto.y_intervals(b)); + ConstraintProto* new_interval_x = context_->AddConstraint(); *new_interval_x = x_interval_ct; new_interval_x->clear_enforcement_literal(); new_interval_x->add_enforcement_literal(presence_in_box_lit); - ConstraintProto* new_interval_y = - context_->working_model->add_constraints(); + ConstraintProto* new_interval_y = context_->AddConstraint(); *new_interval_y = y_interval_ct; new_interval_y->clear_enforcement_literal(); new_interval_y->add_enforcement_literal(presence_in_box_lit); - new_no_overlap_2d.add_x_intervals( - context_->working_model->constraints_size() - 2); - new_no_overlap_2d.add_y_intervals( - context_->working_model->constraints_size() - 1); + new_no_overlap_2d.add_x_intervals(context_->NumConstraints() - 2); + new_no_overlap_2d.add_y_intervals(context_->NumConstraints() - 1); } - context_->working_model->add_constraints()->mutable_no_overlap_2d()->Swap( + context_->AddConstraint()->mutable_no_overlap_2d()->Swap( &new_no_overlap_2d); } @@ -6707,8 +6778,7 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( return context_->NotifyThatModelIsUnsat( "A mandatory box cannot be placed in any position"); } - auto* exactly_one = - context_->working_model->add_constraints()->mutable_exactly_one(); + auto* exactly_one = context_->AddConstraint()->mutable_exactly_one(); for (const int presence_literal : presence_literals) { exactly_one->add_literals(presence_literal); } @@ -6719,9 +6789,8 @@ bool CpModelPresolver::ExpandEncoded2DBinPacking( areas.Add(disjoint_packing_presolve_result.bins[bin_index].bin_area); } solution_crush_.AssignVariableToPackingArea( - areas, *context_->working_model, proto.x_intervals(), proto.y_intervals(), + areas, context_->WorkingModel(), proto.x_intervals(), proto.y_intervals(), box_in_area_lits); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats( "no_overlap_2d: fixed boxes partition available space, converted " "to optional regions"); @@ -6754,7 +6823,7 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { if (context_->StartMin(interval_index) > context_->EndMax(interval_index)) { const ConstraintProto& interval_ct = - context_->working_model->constraints(interval_index); + context_->Constraint(interval_index); if (interval_ct.enforcement_literal_size() == 1) { const int literal = interval_ct.enforcement_literal(0); if (!context_->SetLiteralToFalse(literal)) { @@ -6768,14 +6837,13 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { if (context_->SizeMin(interval_index) < 0) { const ConstraintProto& interval_ct = - context_->working_model->constraints(interval_index); + context_->Constraint(interval_index); if (interval_ct.enforcement_literal().empty()) { bool domain_changed = false; // Size can't be negative. - if (!context_->IntersectDomainWith( - interval_ct.interval().size(), - Domain(0, std::numeric_limits::max()), - &domain_changed)) { + if (!context_->IntersectDomainWith(interval_ct.interval().size(), + Domain(0, kint64max), + &domain_changed)) { return false; } } @@ -6876,13 +6944,12 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { if (boxes.size() <= 1) continue; NoOverlap2DConstraintProto* new_no_overlap_2d = - context_->working_model->add_constraints()->mutable_no_overlap_2d(); + context_->AddConstraint()->mutable_no_overlap_2d(); for (const int b : boxes) { new_no_overlap_2d->add_x_intervals(proto.x_intervals(b)); new_no_overlap_2d->add_y_intervals(proto.y_intervals(b)); } } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("no_overlap_2d: split into disjoint components"); return RemoveConstraint(ct); } @@ -6905,14 +6972,13 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { IndexedInterval::ComparatorByStart()); ConstructOverlappingSets(absl::MakeSpan(indexed_intervals), &no_overlaps); for (int i = 0; i < no_overlaps.size(); ++i) { - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); // Unfortunately, the Assign() method does not work in or-tools as the // protobuf int32_t type is not the int type. for (const int i : no_overlaps[i]) { new_ct->mutable_no_overlap()->add_intervals(i); } } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } @@ -6947,18 +7013,15 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { new_no_overlap_2d.add_y_intervals(proto.y_intervals(i)); } for (const Rectangle& fixed_box : fixed_boxes) { - const int item_x_interval = - context_->working_model->constraints().size(); + const int item_x_interval = context_->NumConstraints(); IntervalConstraintProto* new_interval = - context_->working_model->add_constraints()->mutable_interval(); + context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(fixed_box.x_min.value()); new_interval->mutable_size()->set_offset(fixed_box.SizeX().value()); new_interval->mutable_end()->set_offset(fixed_box.x_max.value()); - const int item_y_interval = - context_->working_model->constraints().size(); - new_interval = - context_->working_model->add_constraints()->mutable_interval(); + const int item_y_interval = context_->NumConstraints(); + new_interval = context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(fixed_box.y_min.value()); new_interval->mutable_size()->set_offset(fixed_box.SizeY().value()); new_interval->mutable_end()->set_offset(fixed_box.y_max.value()); @@ -6966,9 +7029,8 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { new_no_overlap_2d.add_x_intervals(item_x_interval); new_no_overlap_2d.add_y_intervals(item_y_interval); } - context_->working_model->add_constraints()->mutable_no_overlap_2d()->Swap( + context_->AddConstraint()->mutable_no_overlap_2d()->Swap( &new_no_overlap_2d); - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("no_overlap_2d: presolved fixed rectangles"); return RemoveConstraint(ct); } @@ -6994,7 +7056,7 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { absl::Span indexes = non_fixed_components[i]; NoOverlap2DConstraintProto* new_no_overlap_2d = - context_->working_model->add_constraints()->mutable_no_overlap_2d(); + context_->AddConstraint()->mutable_no_overlap_2d(); for (const int idx : indexes) { const int b = non_fixed_boxes[idx].box_index; new_no_overlap_2d->add_x_intervals(proto.x_intervals(b)); @@ -7005,7 +7067,6 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { new_no_overlap_2d->add_y_intervals(proto.y_intervals(b)); } } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats( "no_overlap_2d: split into disjoint components duplicating fixed " "boxes"); @@ -7022,11 +7083,13 @@ bool CpModelPresolver::PresolveNoOverlap2D(int /*c*/, ConstraintProto* ct) { } namespace { + LinearExpressionProto ConstantExpressionProto(int64_t value) { LinearExpressionProto expr; expr.set_offset(value); return expr; } + } // namespace void CpModelPresolver::DetectDuplicateIntervals( @@ -7110,7 +7173,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { "cumulative: merged demands of identical interval"); } else { LinearConstraintProto* sum_of_terms = - context_->working_model->add_constraints()->mutable_linear(); + context_->AddConstraint()->mutable_linear(); std::vector vars; vars.reserve(terms.size()); std::vector coeffs; @@ -7130,7 +7193,6 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { sum_of_terms->add_vars(variable_demand); sum_of_terms->add_coeffs(-1); FillDomainInProto(0, sum_of_terms); - context_->UpdateNewConstraintsVariableUsage(); proto->add_intervals(interval); LinearExpressionProto* demand = proto->add_demands(); @@ -7176,7 +7238,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (start_min > end_max) { if (context_->ConstraintIsOptional(interval_index)) { ConstraintProto* interval_ct = - context_->working_model->mutable_constraints(interval_index); + context_->MutableConstraint(interval_index); DCHECK_EQ(interval_ct->enforcement_literal_size(), 1); const int literal = interval_ct->enforcement_literal(0); if (!context_->SetLiteralToFalse(literal)) { @@ -7194,7 +7256,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (context_->ConstraintIsOptional(interval_index)) { if (context_->SizeMin(interval_index) > 0) { ConstraintProto* interval_ct = - context_->working_model->mutable_constraints(interval_index); + context_->MutableConstraint(interval_index); DCHECK_EQ(interval_ct->enforcement_literal_size(), 1); const int literal = interval_ct->enforcement_literal(0); if (!context_->SetLiteralToFalse(literal)) { @@ -7206,7 +7268,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { } else { // Interval performed. // Try to set the size to 0. const ConstraintProto& interval_ct = - context_->working_model->constraints(interval_index); + context_->Constraint(interval_index); if (!context_->IntersectDomainWith(interval_ct.interval().size(), {0, 0})) { return true; @@ -7285,14 +7347,13 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (components.size() > 1) { for (const std::vector& component : components) { CumulativeConstraintProto* new_cumulative = - context_->working_model->add_constraints()->mutable_cumulative(); + context_->AddConstraint()->mutable_cumulative(); for (const int i : component) { new_cumulative->add_intervals(proto->intervals(i)); *new_cumulative->add_demands() = proto->demands(i); } *new_cumulative->mutable_capacity() = proto->capacity(); } - context_->UpdateNewConstraintsVariableUsage(); context_->UpdateRuleStats("cumulative: split into disjoint components"); return RemoveConstraint(ct); } @@ -7391,8 +7452,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { int64_t sum_of_max_demands = 0; for (int i = 0; i < proto->intervals_size(); ++i) { const int interval_index = proto->intervals(i); - const ConstraintProto& interval_ct = - context_->working_model->constraints(interval_index); + const ConstraintProto& interval_ct = context_->Constraint(interval_index); const LinearExpressionProto& demand_expr = proto->demands(i); sum_of_max_demands += context_->MaxOf(demand_expr); @@ -7408,8 +7468,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (max_of_performed_demand_mins > context_->MinOf(capacity_expr)) { context_->UpdateRuleStats("cumulative: propagate min capacity"); if (!context_->IntersectDomainWith( - capacity_expr, Domain(max_of_performed_demand_mins, - std::numeric_limits::max()))) { + capacity_expr, Domain(max_of_performed_demand_mins, kint64max))) { return true; } } @@ -7464,8 +7523,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { const int index = proto->intervals(i); // TODO(user): adapt in the presence of optional intervals. if (context_->ConstraintIsOptional(index)) has_optional_interval = true; - const ConstraintProto& ct = - context_->working_model->constraints(proto->intervals(i)); + const ConstraintProto& ct = context_->Constraint(proto->intervals(i)); const IntervalConstraintProto& interval = ct.interval(); start_exprs[i] = interval.start(); @@ -7499,9 +7557,8 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (ct.enforcement_literal().empty()) { context_->UpdateRuleStats( "cumulative: demand_max exceeds capacity max"); - if (!context_->IntersectDomainWith( - demand_expr, - Domain(std::numeric_limits::min(), capacity_max))) { + if (!context_->IntersectDomainWith(demand_expr, + Domain(kint64min, capacity_max))) { return true; } } else { @@ -7516,7 +7573,7 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { if (num_greater_half_capacity == num_intervals) { if (num_duration_one == num_intervals && !has_optional_interval) { context_->UpdateRuleStats("cumulative: convert to all_different"); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); auto* arg = new_ct->mutable_all_diff(); for (const LinearExpressionProto& expr : start_exprs) { *arg->add_exprs() = expr; @@ -7526,14 +7583,13 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { for (const LinearExpressionProto& expr : proto->demands()) { if (capacity_min >= context_->MaxOf(expr)) continue; LinearConstraintProto* fit = - context_->working_model->add_constraints()->mutable_linear(); + context_->AddConstraint()->mutable_linear(); fit->add_domain(0); - fit->add_domain(std::numeric_limits::max()); + fit->add_domain(kint64max); AddLinearExpressionToLinearConstraint(capacity_expr, 1, fit); AddLinearExpressionToLinearConstraint(expr, -1, fit); } } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } else { context_->UpdateRuleStats("cumulative: convert to no_overlap"); @@ -7543,14 +7599,11 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { const LinearExpressionProto& demand_expr = proto->demands(i); const int64_t demand_max = context_->MaxOf(demand_expr); if (demand_max > context_->MinOf(capacity_expr)) { - ConstraintProto* capacity_gt = - context_->working_model->add_constraints(); + ConstraintProto* capacity_gt = context_->AddConstraint(); *capacity_gt->mutable_enforcement_literal() = - context_->working_model->constraints(proto->intervals(i)) - .enforcement_literal(); + context_->Constraint(proto->intervals(i)).enforcement_literal(); capacity_gt->mutable_linear()->add_domain(0); - capacity_gt->mutable_linear()->add_domain( - std::numeric_limits::max()); + capacity_gt->mutable_linear()->add_domain(kint64max); AddLinearExpressionToLinearConstraint(capacity_expr, 1, capacity_gt->mutable_linear()); AddLinearExpressionToLinearConstraint(demand_expr, -1, @@ -7558,12 +7611,11 @@ bool CpModelPresolver::PresolveCumulative(ConstraintProto* ct) { } } - ConstraintProto* new_ct = context_->working_model->add_constraints(); + ConstraintProto* new_ct = context_->AddConstraint(); auto* arg = new_ct->mutable_no_overlap(); for (const int interval : proto->intervals()) { arg->add_intervals(interval); } - context_->UpdateNewConstraintsVariableUsage(); return RemoveConstraint(ct); } } @@ -7991,7 +8043,7 @@ bool CpModelPresolver::PresolveReservoir(ConstraintProto* ct) { (num_positives == 0 || num_negatives == 0)) { // If all level_changes have the same sign, and if the initial state is // always feasible, we do not care about the order, just the sum. - auto* const sum_ct = context_->working_model->add_constraints(); + auto* const sum_ct = context_->AddConstraint(); auto* const sum = sum_ct->mutable_linear(); int64_t fixed_contrib = 0; for (int i = 0; i < proto.level_changes_size(); ++i) { @@ -8045,8 +8097,7 @@ bool CpModelPresolver::PresolveReservoir(ConstraintProto* ct) { absl::flat_hash_set> time_active_set; for (int i = 0; i < proto.level_changes_size(); ++i) { const LinearExpressionProto& time = proto.time_exprs(i); - const int var = context_->IsFixed(time) ? std::numeric_limits::min() - : time.vars(0); + const int var = context_->IsFixed(time) ? kint32min : time.vars(0); const int64_t coeff = context_->IsFixed(time) ? 0 : time.coeffs(0); const std::tuple key = std::make_tuple( var, coeff, @@ -8069,37 +8120,38 @@ bool CpModelPresolver::PresolveReservoir(ConstraintProto* ct) { // rather than modifying more than once the proto. void CpModelPresolver::ConvertToBoolAnd() { absl::flat_hash_map ref_to_bool_and; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); std::vector to_remove; for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (HasEnforcementLiteral(ct)) continue; if (ct.constraint_case() == ConstraintProto::kBoolOr && ct.bool_or().literals().size() == 2) { - AddImplication(NegatedRef(ct.bool_or().literals(0)), - ct.bool_or().literals(1), context_->working_model, - &ref_to_bool_and); + AddImplicationWithMerging(NegatedRef(ct.bool_or().literals(0)), + ct.bool_or().literals(1), context_, + &ref_to_bool_and); to_remove.push_back(c); continue; } if (ct.constraint_case() == ConstraintProto::kAtMostOne && ct.at_most_one().literals().size() == 2) { - AddImplication(ct.at_most_one().literals(0), - NegatedRef(ct.at_most_one().literals(1)), - context_->working_model, &ref_to_bool_and); + AddImplicationWithMerging(ct.at_most_one().literals(0), + NegatedRef(ct.at_most_one().literals(1)), + context_, &ref_to_bool_and); to_remove.push_back(c); continue; } } - context_->UpdateNewConstraintsVariableUsage(); + DCHECK(context_->ConstraintVariableUsageIsConsistent()); for (const int c : to_remove) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); CHECK(RemoveConstraint(ct)); context_->UpdateConstraintVariableUsage(c); } + DCHECK(context_->ConstraintVariableUsageIsConsistent()); } void CpModelPresolver::RunPropagatorsForConstraint(const ConstraintProto& ct) { @@ -8229,7 +8281,7 @@ void CpModelPresolver::Probe() { std::vector to_update; for (const int c : context_->VarToConstraints(decision_var)) { if (c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.enforcement_literal().size() > 2) { // Any l for which decision => l can be removed. // @@ -8268,7 +8320,7 @@ void CpModelPresolver::Probe() { if (has_false_literal) { // Reduce to implication. - auto* mutable_ct = context_->working_model->mutable_constraints(c); + auto* mutable_ct = context_->MutableConstraint(c); mutable_ct->Clear(); mutable_ct->add_enforcement_literal(decision_ref); mutable_ct->mutable_bool_and()->add_literals(NegatedRef(false_ref)); @@ -8281,8 +8333,7 @@ void CpModelPresolver::Probe() { if (simplification_possible) { int new_size = 0; auto* mutable_enforcements = - context_->working_model->mutable_constraints(c) - ->mutable_enforcement_literal(); + context_->MutableConstraint(c)->mutable_enforcement_literal(); for (const int ref : ct.enforcement_literal()) { if (PositiveRef(ref) != decision_var && assignment.LiteralIsTrue(mapping->Literal(ref))) { @@ -8328,7 +8379,7 @@ void CpModelPresolver::Probe() { // This will later be merged with the current implications and removed // if it is a duplicate. auto* mutable_bool_or = - context_->working_model->mutable_constraints(c)->mutable_bool_or(); + context_->MutableConstraint(c)->mutable_bool_or(); mutable_bool_or->mutable_literals()->Clear(); mutable_bool_or->add_literals(decision_ref); mutable_bool_or->add_literals(true_ref); @@ -8340,7 +8391,7 @@ void CpModelPresolver::Probe() { if (simplification_possible) { int new_size = 0; auto* mutable_bool_or = - context_->working_model->mutable_constraints(c)->mutable_bool_or(); + context_->MutableConstraint(c)->mutable_bool_or(); for (const int ref : ct.bool_or().literals()) { if (PositiveRef(ref) != decision_var && assignment.LiteralIsFalse(mapping->Literal(ref))) { @@ -8411,7 +8462,7 @@ void CpModelPresolver::Probe() { int num_equiv = 0; int num_changed_bounds = 0; - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); auto* integer_trail = model.GetOrCreate(); for (int var = 0; var < num_variables; ++var) { // Restrict IntegerVariable domain. @@ -8464,9 +8515,9 @@ void CpModelPresolver::Probe() { // TODO(user): On large model, most of the time is spend in this copy, // clearing and updating the constraint variable graph... int64_t num_literals_before = 0; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() == ConstraintProto::kAtMostOne) { std::vector clique; for (const int ref : ct->at_most_one().literals()) { @@ -8513,7 +8564,7 @@ void CpModelPresolver::Probe() { if (clique.empty()) continue; num_new_cliques++; num_literals_after += clique.size(); - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); for (const Literal literal : clique) { const int var = mapping->GetProtoVariableFromBooleanVariable(literal.Variable()); @@ -8526,9 +8577,11 @@ void CpModelPresolver::Probe() { } // Make sure we do not have duplicate variable reference. - PresolveAtMostOne(ct); + // + // Tricky: note that it is important to not use dual reduction here as not + // all constraints are in the proto during the loop. + PresolveAtMostOne(ct, /*use_dual_reduction=*/false); } - context_->UpdateNewConstraintsVariableUsage(); if (num_new_cliques != num_old_cliques) { context_->UpdateRuleStats("at_most_one: transformed into max clique"); } @@ -8542,6 +8595,7 @@ void CpModelPresolver::Probe() { " with ", Plural(num_literals_after, "literal"))); } } + DCHECK(context_->ConstraintVariableUsageIsConsistent()); } namespace { @@ -8575,7 +8629,7 @@ bool CpModelPresolver::PresolvePureSatPart() { // Compute a dense re-indexing for the Booleans of the problem. int num_variables = 0; int num_ignored_variables = 0; - const int total_num_vars = context_->working_model->variables().size(); + const int total_num_vars = context_->NumVariables(); std::vector new_index(total_num_vars, -1); std::vector new_to_old_index; for (int i = 0; i < total_num_vars; ++i) { @@ -8614,6 +8668,7 @@ bool CpModelPresolver::PresolvePureSatPart() { local_model.GetOrCreate()->MergeWithGlobalTimeLimit(time_limit_); auto* sat_solver = local_model.GetOrCreate(); auto* graph = local_model.GetOrCreate(); + *local_model.GetOrCreate() = context_->params(); sat_solver->SetNumVariables(num_variables); // Fix variables if any. Because we might not have reached the presove "fixed @@ -8635,8 +8690,8 @@ bool CpModelPresolver::PresolvePureSatPart() { int num_removed_constraints = 0; int num_ignored_constraints = 0; const bool load_amo = context_->params().load_at_most_ones_in_sat_presolve(); - for (int i = 0; i < context_->working_model->constraints_size(); ++i) { - const ConstraintProto& ct = context_->working_model->constraints(i); + for (int i = 0; i < context_->NumConstraints(); ++i) { + const ConstraintProto& ct = context_->Constraint(i); if (ct.constraint_case() == ConstraintProto::kBoolOr) { ++num_removed_constraints; @@ -8649,7 +8704,7 @@ bool CpModelPresolver::PresolvePureSatPart() { } sat_solver->AddProblemClause(clause); - context_->working_model->mutable_constraints(i)->Clear(); + context_->ClearConstraint(i); context_->UpdateConstraintVariableUsage(i); continue; } @@ -8665,7 +8720,7 @@ bool CpModelPresolver::PresolvePureSatPart() { if (!graph->AddAtMostOne(clause)) return false; ++num_removed_constraints; - context_->working_model->mutable_constraints(i)->Clear(); + context_->ClearConstraint(i); context_->UpdateConstraintVariableUsage(i); continue; } @@ -8682,7 +8737,7 @@ bool CpModelPresolver::PresolvePureSatPart() { sat_solver->AddProblemClause(clause); ++num_removed_constraints; - context_->working_model->mutable_constraints(i)->Clear(); + context_->ClearConstraint(i); context_->UpdateConstraintVariableUsage(i); continue; } @@ -8709,7 +8764,7 @@ bool CpModelPresolver::PresolvePureSatPart() { sat_solver->AddProblemClause(clause); } - context_->working_model->mutable_constraints(i)->Clear(); + context_->ClearConstraint(i); context_->UpdateConstraintVariableUsage(i); continue; } @@ -8847,7 +8902,7 @@ bool CpModelPresolver::PresolvePureSatPart() { if (amo.empty()) break; // Re-add the amo to the proto. - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); ct->mutable_at_most_one()->mutable_literals()->Reserve(amo.size()); for (Literal l : amo) { // TODO(user): ProbeAndFindEquivalentLiteral() do not register newly @@ -8865,7 +8920,6 @@ bool CpModelPresolver::PresolvePureSatPart() { can_be_removed[l.Variable().value()] = false; } } - context_->UpdateNewConstraintsVariableUsage(); // Apply the "old" SAT presolve. SatPresolver sat_presolver(&sat_postsolver, logger_); @@ -8896,12 +8950,12 @@ bool CpModelPresolver::PresolvePureSatPart() { if (new_num_variables > num_variables) { VLOG(1) << "New variables added by the SAT presolver."; for (int i = num_variables; i < new_num_variables; ++i) { - new_to_old_index.push_back(context_->working_model->variables().size()); - IntegerVariableProto* var_proto = - context_->working_model->add_variables(); - var_proto->add_domain(0); - var_proto->add_domain(1); + new_to_old_index.push_back(context_->NumVariables()); + context_->NewBoolVar("BVA"); } + + // TODO(user): NewBoolVar() already call this each time. + // Provide a batch interface? or we don't care. context_->InitializeNewDomains(); } @@ -8912,11 +8966,7 @@ bool CpModelPresolver::PresolvePureSatPart() { } // Add the presolver clauses back into the model. - ExtractClauses(/*merge_into_bool_and=*/true, new_to_old_index, sat_presolver, - context_->working_model); - - // Update the constraints <-> variables graph. - context_->UpdateNewConstraintsVariableUsage(); + ExtractClausesToContext(new_to_old_index, sat_presolver, context_); // We mark as removed any variables removed by the pure SAT presolve. // This is mainly to discover or avoid bug as we might have stale entries @@ -8929,10 +8979,8 @@ bool CpModelPresolver::PresolvePureSatPart() { } // Add the sat_postsolver clauses to mapping_model. - const std::string name = - absl::GetFlag(FLAGS_cp_model_debug_postsolve) ? "sat_postsolver" : ""; - ExtractClauses(/*merge_into_bool_and=*/false, new_to_old_index, - sat_postsolver, context_->mapping_model, name); + ExtractClausesToMappingModelProto(new_to_old_index, sat_postsolver, + context_->mapping_model); return true; } @@ -8956,13 +9004,13 @@ bool CpModelPresolver::PresolvePureSatProblem() { if (context_->params().keep_all_feasible_solutions_in_presolve()) return true; Model local_model; - LratProofHandler* lrat_proof_handler = context_->lrat_proof_handler.get(); + LratProofHandler* lrat_proof_handler = context_->lrat_proof_handler; if (lrat_proof_handler != nullptr) { local_model.Register(lrat_proof_handler); } local_model.GetOrCreate()->MergeWithGlobalTimeLimit(time_limit_); auto* sat_solver = local_model.GetOrCreate(); - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); sat_solver->SetNumVariables(num_variables); std::vector new_to_old_index; @@ -8976,8 +9024,8 @@ bool CpModelPresolver::PresolvePureSatProblem() { return Literal(BooleanVariable(PositiveRef(ref)), RefIsPositive(ref)); }; std::vector clause; - for (int i = 0; i < context_->working_model->constraints_size(); ++i) { - const ConstraintProto& ct = context_->working_model->constraints(i); + for (int i = 0; i < context_->NumConstraints(); ++i) { + const ConstraintProto& ct = context_->Constraint(i); CHECK_EQ(ct.constraint_case(), ConstraintProto::kBoolOr); clause.clear(); for (const int ref : ct.enforcement_literal()) { @@ -8987,7 +9035,7 @@ bool CpModelPresolver::PresolvePureSatProblem() { clause.push_back(ref_to_literal(ref)); } sat_solver->AddProblemClause(clause, /*one_based_cnf_index=*/i + 1); - context_->working_model->mutable_constraints(i)->Clear(); + context_->ClearConstraint(i); context_->UpdateConstraintVariableUsage(i); } if (sat_solver->ModelIsUnsat()) return false; @@ -9041,19 +9089,13 @@ bool CpModelPresolver::PresolvePureSatProblem() { context_)) { return false; } - // TODO(user): can we improve ExtractClauses() to avoid the intermediate - // container? + // TODO(user): can we improve ExtractClausesToContext() to avoid the + // intermediate container? BasicClauseContainer clauses_container; if (!sat_solver->ExtractClauses(&clauses_container)) return false; - ExtractClauses(/*merge_into_bool_and=*/true, new_to_old_index, - clauses_container, context_->working_model); - const std::string name = - absl::GetFlag(FLAGS_cp_model_debug_postsolve) ? "sat_postsolver" : ""; - ExtractClauses(/*merge_into_bool_and=*/false, new_to_old_index, - sat_postsolver, context_->mapping_model, name); - - // Update the constraints <-> variables graph. - context_->UpdateNewConstraintsVariableUsage(); + ExtractClausesToContext(new_to_old_index, clauses_container, context_); + ExtractClausesToMappingModelProto(new_to_old_index, sat_postsolver, + context_->mapping_model); // We mark as removed any variables removed by the pure SAT presolve. // This is mainly to discover or avoid bug as we might have stale entries @@ -9076,9 +9118,9 @@ void CpModelPresolver::ShiftObjectiveWithExactlyOnes() { } std::vector exos; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (!ct.enforcement_literal().empty()) continue; if (ct.constraint_case() == ConstraintProto::kExactlyOne) { exos.push_back(c); @@ -9099,11 +9141,11 @@ void CpModelPresolver::ShiftObjectiveWithExactlyOnes() { int num_shifts = 0; for (int i = 0; i < 3; ++i) { for (const int c : exos) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); const int num_terms = ct.exactly_one().literals().size(); if (num_terms <= 1) continue; - int64_t min_obj = std::numeric_limits::max(); - int64_t second_min = std::numeric_limits::max(); + int64_t min_obj = kint64max; + int64_t second_min = kint64max; for (int i = 0; i < num_terms; ++i) { const int literal = ct.exactly_one().literals(i); const int64_t var_obj = context_->ObjectiveCoeff(PositiveRef(literal)); @@ -9131,13 +9173,13 @@ void CpModelPresolver::ShiftObjectiveWithExactlyOnes() { } bool CpModelPresolver::PropagateObjective() { - if (!context_->working_model->has_objective()) return true; + if (!context_->WorkingModel().has_objective()) return true; if (context_->ModelIsUnsat()) return false; context_->WriteObjectiveToProto(); int64_t min_activity = 0; int64_t max_variation = 0; - const CpObjectiveProto& objective = context_->working_model->objective(); + const CpObjectiveProto& objective = context_->WorkingModel().objective(); const int num_terms = objective.vars().size(); for (int i = 0; i < num_terms; ++i) { const int var = objective.vars(i); @@ -9222,8 +9264,8 @@ void CpModelPresolver::ExpandObjective() { return; } - const int num_variables = context_->working_model->variables_size(); - const int num_constraints = context_->working_model->constraints_size(); + const int num_variables = context_->NumVariables(); + const int num_constraints = context_->NumConstraints(); // We consider two types of shifted variables (X - LB(X)) and (UB(X) - X). const auto get_index = [](int var, bool to_lb) { @@ -9250,7 +9292,7 @@ void CpModelPresolver::ExpandObjective() { for (int c = 0; c < num_constraints; ++c) { if (num_entries > kNumEntriesThreshold) break; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (!ct.enforcement_literal().empty()) continue; // Deal with exactly one. @@ -9260,7 +9302,7 @@ void CpModelPresolver::ExpandObjective() { // one. We thus make sure the constraint is re-presolved if for some reason // we didn't reach the fixed point before calling this code. if (ct.constraint_case() == ConstraintProto::kExactlyOne) { - if (PresolveExactlyOne(context_->working_model->mutable_constraints(c))) { + if (PresolveExactlyOne(context_->MutableConstraint(c))) { context_->UpdateConstraintVariableUsage(c); } } @@ -9430,8 +9472,7 @@ void CpModelPresolver::ExpandObjective() { const bool to_lb = (index % 2) == 0; if (obj_coeff > 0 == to_lb) { - const ConstraintProto& ct = - context_->working_model->constraints(index_to_best_c[index]); + const ConstraintProto& ct = context_->Constraint(index_to_best_c[index]); if (ct.constraint_case() == ConstraintProto::kExactlyOne) { int64_t shift = 0; for (const int lit : ct.exactly_one().literals()) { @@ -9496,7 +9537,7 @@ bool CpModelPresolver::MergeCliqueConstraintsHelper( std::vector>& cliques, std::string_view entry_name, PresolveTimer& timer) { if (cliques.empty()) return false; // Nothing has changed. - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); int old_num_clique_constraints = cliques.size(); int old_num_entries = 0; for (const std::vector& clique : cliques) { @@ -9514,11 +9555,12 @@ bool CpModelPresolver::MergeCliqueConstraintsHelper( // so this should never return UNSAT. CHECK(graph->AddAtMostOne(clique)); } + + // We shouldn't be UNSAT here. CHECK(graph->DetectEquivalences()); - graph->TransformIntoMaxCliques( + CHECK(graph->TransformIntoMaxCliques( &cliques, - SafeDoubleToInt64(context_->params().merge_no_overlap_work_limit())); - + SafeDoubleToInt64(context_->params().merge_no_overlap_work_limit()))); time_limit_->ResetHistory(); // Update the number of constraints and entries after the max-clique. @@ -9548,14 +9590,14 @@ bool CpModelPresolver::MergeNoOverlapConstraints() { if (context_->ModelIsUnsat()) return false; if (time_limit_->LimitReached()) return true; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); // Extract the no-overlap constraints with no enforcement literals. // TODO(user): generalize this to merge constraints with the same // enforcement literals? std::vector disjunctive_index; std::vector> cliques; for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kNoOverlap) continue; if (HasEnforcementLiteral(ct)) continue; std::vector clique; @@ -9573,21 +9615,19 @@ bool CpModelPresolver::MergeNoOverlapConstraints() { // Remove previous no_overlap constraints and add the new recomputed ones. for (int i = 0; i < cliques.size(); ++i) { const int ct_index = disjunctive_index[i]; - if (RemoveConstraint( - context_->working_model->mutable_constraints(ct_index))) { + if (RemoveConstraint(context_->MutableConstraint(ct_index))) { context_->UpdateConstraintVariableUsage(ct_index); } } for (int i = 0; i < cliques.size(); ++i) { if (cliques[i].empty()) continue; - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); for (const Literal l : cliques[i]) { CHECK(l.IsPositive()); ct->mutable_no_overlap()->add_intervals(l.Variable().value()); } } context_->UpdateRuleStats("no_overlap: merged constraints"); - context_->UpdateNewConstraintsVariableUsage(); return true; } @@ -9596,7 +9636,7 @@ bool CpModelPresolver::MergeNoOverlap2DConstraints() { if (context_->ModelIsUnsat()) return false; if (time_limit_->LimitReached()) return true; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); // Extract the no-overlap constraints with no enforcement literals. // TODO(user): generalize this to merge constraints with the same // enforcement literals? @@ -9605,7 +9645,7 @@ bool CpModelPresolver::MergeNoOverlap2DConstraints() { absl::flat_hash_map, int> rectangle_to_index; std::vector> index_to_rectangle; for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kNoOverlap2D) continue; if (HasEnforcementLiteral(ct)) continue; std::vector clique; @@ -9628,14 +9668,13 @@ bool CpModelPresolver::MergeNoOverlap2DConstraints() { // Remove previous no_overlap constraints and add the new recomputed ones. for (int i = 0; i < cliques.size(); ++i) { const int ct_index = no_overlap2d_index[i]; - if (RemoveConstraint( - context_->working_model->mutable_constraints(ct_index))) { + if (RemoveConstraint(context_->MutableConstraint(ct_index))) { context_->UpdateConstraintVariableUsage(ct_index); } } for (int i = 0; i < cliques.size(); ++i) { if (cliques[i].empty()) continue; - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); for (const Literal l : cliques[i]) { CHECK(l.IsPositive()); const std::pair rect = index_to_rectangle[l.Variable().value()]; @@ -9644,7 +9683,6 @@ bool CpModelPresolver::MergeNoOverlap2DConstraints() { } } context_->UpdateRuleStats("no_overlap_2d: merged constraints"); - context_->UpdateNewConstraintsVariableUsage(); return true; } @@ -9676,14 +9714,14 @@ void CpModelPresolver::TransformIntoMaxCliques() { if (RefIsPositive(ref)) return Literal(BooleanVariable(ref), true); return Literal(BooleanVariable(NegatedRef(ref)), false); }; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); // Extract the bool_and and at_most_one constraints. // TODO(user): use probing info? std::vector> cliques; for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() == ConstraintProto::kAtMostOne) { std::vector clique; for (const int ref : ct->at_most_one().literals()) { @@ -9711,7 +9749,7 @@ void CpModelPresolver::TransformIntoMaxCliques() { // We reuse the max-clique code from sat. Model local_model; - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); local_model.GetOrCreate()->Resize(num_variables); auto* graph = local_model.GetOrCreate(); graph->Resize(num_variables); @@ -9749,7 +9787,7 @@ void CpModelPresolver::TransformIntoMaxCliques() { if (clique.empty()) continue; num_new_cliques++; num_literals_after += clique.size(); - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); for (const Literal literal : clique) { if (literal.IsPositive()) { ct->mutable_at_most_one()->add_literals(literal.Variable().value()); @@ -9760,9 +9798,11 @@ void CpModelPresolver::TransformIntoMaxCliques() { } // Make sure we do not have duplicate variable reference. - PresolveAtMostOne(ct); + // + // Tricky: note that it is important to not use dual reduction here as not + // all constraints are in the proto during the loop. + PresolveAtMostOne(ct, /*use_dual_reduction=*/false); } - context_->UpdateNewConstraintsVariableUsage(); if (num_new_cliques != num_old_cliques) { context_->UpdateRuleStats("at_most_one: transformed into max clique"); } @@ -9775,6 +9815,115 @@ void CpModelPresolver::TransformIntoMaxCliques() { } } +void CpModelPresolver::SplitNoOverlapAndCumulativeConstraints() { + if (time_limit_->LimitReached()) return; + if (context_->ModelIsUnsat()) return; + PresolveTimer timer(__FUNCTION__, logger_, time_limit_); + std::vector all_no_overlap_intervals; + std::vector all_no_overlap_or_cumulative_constraints; + for (int c = 0; c < context_->NumConstraints(); ++c) { + const ConstraintProto& ct = context_->Constraint(c); + if (ct.constraint_case() == ConstraintProto::kNoOverlap || + ct.constraint_case() == ConstraintProto::kCumulative) { + all_no_overlap_or_cumulative_constraints.push_back(c); + const google::protobuf::RepeatedField& indices = + (ct.constraint_case() == ConstraintProto::kNoOverlap) + ? ct.no_overlap().intervals() + : ct.cumulative().intervals(); + + for (const int interval : indices) { + all_no_overlap_intervals.push_back(interval); + } + } + } + + if (all_no_overlap_intervals.empty()) return; + + int num_split_constraints = 0; + gtl::STLSortAndRemoveDuplicates(&all_no_overlap_intervals); + const std::vector> precedences = + DetectIntervalPrecedences(context_->WorkingModel(), known_model_linear2_, + all_no_overlap_intervals); + std::vector intervals; + for (const int c : all_no_overlap_or_cumulative_constraints) { + intervals.clear(); + const ConstraintProto& ct = context_->Constraint(c); + const bool is_no_overlap = + (ct.constraint_case() == ConstraintProto::kNoOverlap); + + const google::protobuf::RepeatedField& interval_indices = + is_no_overlap ? ct.no_overlap().intervals() + : ct.cumulative().intervals(); + intervals.reserve(interval_indices.size()); + for (const int interval : interval_indices) { + const ConstraintProto& interval_ct = context_->Constraint(interval); + intervals.push_back(IndexedInterval{ + .index = interval, + .start = context_->MinOf(interval_ct.interval().start()), + .end = context_->MaxOf(interval_ct.interval().end()), + }); + } + const auto components = + IntervalsNonOverlappingComponents(intervals, precedences); + + if (components.size() == 1 && components[0].size() == intervals.size()) { + continue; + } + + const ConstraintProto orig_ct = ct; + + absl::flat_hash_map interval_to_orig_index; + if (!is_no_overlap) { + interval_to_orig_index.reserve(interval_indices.size()); + for (int i = 0; i < interval_indices.size(); ++i) { + auto [it, inserted] = + interval_to_orig_index.insert({interval_indices[i], i}); + if (!inserted) { + context_->UpdateRuleStats( + "TODO: ignored duplicate interval in " + "cumulative constraint"); + return; + } + } + } + + context_->UpdateRuleStats( + absl::StrCat((is_no_overlap ? "no_overlap" : "cumulative"), + ": split using precedences")); + ++num_split_constraints; + + if (!RemoveConstraint(context_->MutableConstraint(c))) { + return; + } + context_->UpdateConstraintVariableUsage(c); + for (const auto& component : components.AsVectorOfSpan()) { + if (is_no_overlap && component.size() <= 1) { + continue; + } + ConstraintProto* new_ct = context_->AddConstraint(); + *new_ct->mutable_enforcement_literal() = orig_ct.enforcement_literal(); + if (is_no_overlap) { + for (const int interval : component) { + new_ct->mutable_no_overlap()->mutable_intervals()->Add(interval); + } + } else { + *new_ct->mutable_cumulative()->mutable_capacity() = + orig_ct.cumulative().capacity(); + + for (const int interval : component) { + new_ct->mutable_cumulative()->mutable_intervals()->Add(interval); + + const auto it = interval_to_orig_index.find(interval); + DCHECK(it != interval_to_orig_index.end()); + *new_ct->mutable_cumulative()->add_demands() = + orig_ct.cumulative().demands(it->second); + } + } + } + } + timer.AddCounter("num_split_constraints", num_split_constraints); +} + void CpModelPresolver::TransformClausesToExactlyOne() { if (context_->ModelIsUnsat()) return; if (!context_->params().find_clauses_that_are_exactly_one()) return; @@ -9784,11 +9933,11 @@ void CpModelPresolver::TransformClausesToExactlyOne() { if (RefIsPositive(ref)) return Literal(BooleanVariable(ref), true); return Literal(BooleanVariable(NegatedRef(ref)), false); }; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); // We reuse the BinaryImplicationGraph code to "propagate" 2-SAT. Model local_model; - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); local_model.GetOrCreate()->Resize(num_variables); auto* graph = local_model.GetOrCreate(); graph->Resize(num_variables); @@ -9800,7 +9949,7 @@ void CpModelPresolver::TransformClausesToExactlyOne() { std::vector clause_indices; std::vector> clauses; for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() == ConstraintProto::kAtMostOne) { tmp_clique.clear(); for (const int ref : ct->at_most_one().literals()) { @@ -9917,10 +10066,8 @@ void CpModelPresolver::TransformClausesToExactlyOne() { ++num_transformed; context_->UpdateRuleStats("clauses: transformed into exactly one"); google::protobuf::RepeatedField tmp = - context_->working_model->constraints(clause_indices[i]) - .bool_or() - .literals(); - *(context_->working_model->mutable_constraints(clause_indices[i]) + context_->Constraint(clause_indices[i]).bool_or().literals(); + *(context_->MutableConstraint(clause_indices[i]) ->mutable_exactly_one() ->mutable_literals()) = tmp; } @@ -9935,7 +10082,7 @@ void CpModelPresolver::TransformClausesToExactlyOne() { bool CpModelPresolver::PresolveOneConstraint(int c) { if (context_->ModelIsUnsat()) return false; - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); // Generic presolve to exploit variable/literal equivalence. if (ExploitEquivalenceRelations(c, ct)) { @@ -9994,25 +10141,34 @@ bool CpModelPresolver::PresolveOneConstraint(int c) { if (PropagateDomainsInLinear(c, ct)) { context_->UpdateConstraintVariableUsage(c); } - if (PresolveSmallLinear(ct)) { + + // The constraint should already be canonicalized at this stage. + if (PresolveSmallLinear(ct, /*canonicalize=*/false)) { context_->UpdateConstraintVariableUsage(c); } - if (PresolveLinearEqualityWithModulo(ct)) { - context_->UpdateConstraintVariableUsage(c); + + bool redo_small_linear_presolve = false; + if (IsLinearEqualityConstraint(*ct)) { + redo_small_linear_presolve = true; + if (PresolveLinearEqualityWithModulo(ct)) { + context_->UpdateConstraintVariableUsage(c); + } } + // We first propagate the domains before calling this presolve rule. if (RemoveSingletonInLinear(ct)) { context_->UpdateConstraintVariableUsage(c); + redo_small_linear_presolve = true; + } + if (redo_small_linear_presolve) { // There is no need to re-do a propagation here, but the constraint // size might have been reduced. if (PresolveSmallLinear(ct)) { context_->UpdateConstraintVariableUsage(c); } } - if (PresolveSmallLinear(ct)) { - context_->UpdateConstraintVariableUsage(c); - } + if (PresolveLinearOnBooleans(ct)) { context_->UpdateConstraintVariableUsage(c); } @@ -10040,6 +10196,8 @@ bool CpModelPresolver::PresolveOneConstraint(int c) { case ConstraintProto::kInterval: return PresolveInterval(c, ct); case ConstraintProto::kInverse: + DCHECK(ct->inverse().f_direct().empty() && + ct->inverse().f_inverse().empty()); return PresolveInverse(ct); case ConstraintProto::kElement: return PresolveElement(c, ct); @@ -10088,10 +10246,8 @@ bool CpModelPresolver::ProcessSetPPCSubset(int subset_c, int superset_c, bool* remove_subset, bool* remove_superset, bool* stop_processing_superset) { - ConstraintProto* subset_ct = - context_->working_model->mutable_constraints(subset_c); - ConstraintProto* superset_ct = - context_->working_model->mutable_constraints(superset_c); + ConstraintProto* subset_ct = context_->MutableConstraint(subset_c); + ConstraintProto* superset_ct = context_->MutableConstraint(superset_c); if ((subset_ct->constraint_case() == ConstraintProto::kBoolOr || subset_ct->constraint_case() == ConstraintProto::kExactlyOne) && @@ -10151,8 +10307,8 @@ bool CpModelPresolver::ProcessSetPPCSubset(int subset_c, int superset_c, if (subset_ct->constraint_case() == ConstraintProto::kExactlyOne && superset_ct->constraint_case() == ConstraintProto::kLinear) { tmp_set->clear(); - int64_t min_sum = std::numeric_limits::max(); - int64_t max_sum = std::numeric_limits::min(); + int64_t min_sum = kint64max; + int64_t max_sum = kint64min; tmp_set->insert(subset_ct->exactly_one().literals().begin(), subset_ct->exactly_one().literals().end()); @@ -10259,7 +10415,7 @@ bool CpModelPresolver::ProcessSetPPCSubset(int subset_c, int superset_c, new_ct.mutable_coeffs()->Truncate(new_size); FillDomainInProto(ReadDomainFromProto(new_ct).AdditionWith(Domain(-best)), &new_ct); - if (!PossibleIntegerOverflow(*context_->working_model, new_ct.vars(), + if (!PossibleIntegerOverflow(context_->WorkingModel(), new_ct.vars(), new_ct.coeffs())) { *superset_ct->mutable_linear() = std::move(new_ct); context_->UpdateConstraintVariableUsage(superset_c); @@ -10295,10 +10451,10 @@ void CpModelPresolver::ProcessSetPPC() { // We use an encoding of literal that allows to index arrays. std::vector temp_literals; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); std::vector relevant_constraints; for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); const auto type = ct->constraint_case(); if (type == ConstraintProto::kBoolOr || type == ConstraintProto::kAtMostOne || @@ -10357,6 +10513,7 @@ void CpModelPresolver::ProcessSetPPC() { absl::flat_hash_set tmp_set; int64_t num_inclusions = 0; + temp_ct_.Clear(); detector.DetectInclusions([&](int subset, int superset) { ++num_inclusions; bool remove_subset = false; @@ -10372,12 +10529,12 @@ void CpModelPresolver::ProcessSetPPC() { return; } if (remove_subset) { - context_->working_model->mutable_constraints(subset_c)->Clear(); + context_->ClearConstraint(subset_c); context_->UpdateConstraintVariableUsage(subset_c); detector.StopProcessingCurrentSubset(); } if (remove_superset) { - context_->working_model->mutable_constraints(superset_c)->Clear(); + context_->ClearConstraint(superset_c); context_->UpdateConstraintVariableUsage(superset_c); detector.StopProcessingCurrentSuperset(); } @@ -10405,9 +10562,9 @@ void CpModelPresolver::DetectIncludedEnforcement() { detector.SetWorkLimit(context_->params().presolve_inclusion_work_limit()); std::vector temp_literals; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->enforcement_literal().size() <= 1) continue; // Make sure there is no x => x. @@ -10442,10 +10599,8 @@ void CpModelPresolver::DetectIncludedEnforcement() { ++num_inclusions; const int subset_c = relevant_constraints[subset]; const int superset_c = relevant_constraints[superset]; - ConstraintProto* subset_ct = - context_->working_model->mutable_constraints(subset_c); - ConstraintProto* superset_ct = - context_->working_model->mutable_constraints(superset_c); + ConstraintProto* subset_ct = context_->MutableConstraint(subset_c); + ConstraintProto* superset_ct = context_->MutableConstraint(superset_c); if (subset_ct->constraint_case() != ConstraintProto::kBoolAnd) return; context_->tmp_literal_set.clear(); @@ -10538,9 +10693,9 @@ bool CpModelPresolver::ProcessEncodingFromLinear( // Preprocess the linear constraints. const ConstraintProto& linear_encoding = - context_->working_model->constraints(linear_encoding_ct_index); + context_->Constraint(linear_encoding_ct_index); int64_t rhs = linear_encoding.linear().domain(0); - int target_ref = std::numeric_limits::min(); + int target_ref = kint32min; std::vector> ref_to_coeffs; const int num_terms = linear_encoding.linear().vars().size(); for (int i = 0; i < num_terms; ++i) { @@ -10549,7 +10704,7 @@ bool CpModelPresolver::ProcessEncodingFromLinear( const auto it = var_to_ref.find(PositiveRef(ref)); if (it == var_to_ref.end()) { - CHECK_EQ(target_ref, std::numeric_limits::min()) << "Uniqueness"; + CHECK_EQ(target_ref, kint32min) << "Uniqueness"; CHECK_EQ(std::abs(coeff), 1); target_ref = coeff == 1 ? ref : NegatedRef(ref); continue; @@ -10566,8 +10721,7 @@ bool CpModelPresolver::ProcessEncodingFromLinear( ref_to_coeffs.push_back({NegatedRef(ref), -coeff}); } } - if (target_ref == std::numeric_limits::min() || - context_->CanBeUsedAsLiteral(target_ref)) { + if (target_ref == kint32min || context_->CanBeUsedAsLiteral(target_ref)) { // We didn't find the unique integer variable. This might have happenned // because by processing other encoding we might end up with a fully boolean // constraint. Just abort, it will be presolved later. @@ -10654,8 +10808,7 @@ bool CpModelPresolver::ProcessEncodingFromLinear( // exactly one! Experiment with this. In particular it might capture // more structure for later heuristic to add the exactly one instead. // This also applies to automata/table/element expansion. - auto* bool_or = - context_->working_model->add_constraints()->mutable_bool_or(); + auto* bool_or = context_->AddConstraint()->mutable_bool_or(); for (const int lit : literals) bool_or->add_literals(lit); bool_or->add_literals(NegatedRef(associated_lit)); } @@ -10663,9 +10816,7 @@ bool CpModelPresolver::ProcessEncodingFromLinear( } // Remove linear constraint now that it is fully encoded. - context_->working_model->mutable_constraints(linear_encoding_ct_index) - ->Clear(); - context_->UpdateNewConstraintsVariableUsage(); + context_->ClearConstraint(linear_encoding_ct_index); context_->UpdateConstraintVariableUsage(linear_encoding_ct_index); return true; } @@ -10702,14 +10853,13 @@ void CpModelPresolver::DetectDuplicateColumns() { if (context_->params().keep_all_feasible_solutions_in_presolve()) return; PresolveTimer timer(__FUNCTION__, logger_, time_limit_); - const int num_vars = context_->working_model->variables().size(); - const int num_constraints = context_->working_model->constraints().size(); + const int num_vars = context_->NumVariables(); + const int num_constraints = context_->NumConstraints(); // Our current implementation require almost a full copy. // First construct a transpose var to columns (constraint_index, coeff). - std::vector flat_vars; - std::vector> flat_terms; - CompactVectorVector> var_to_columns; + CompactVectorVectorBuilder> + var_to_columns_builder; // We will only support columns that include: // - objective @@ -10725,7 +10875,7 @@ void CpModelPresolver::DetectDuplicateColumns() { std::vector appear_in_amo(num_vars, false); std::vector appear_in_bool_constraint(num_vars, false); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); absl::Span literals; bool is_amo = false; @@ -10745,8 +10895,7 @@ void CpModelPresolver::DetectDuplicateColumns() { if (!RefIsPositive(lit)) continue; if (is_amo) appear_in_amo[lit] = true; appear_in_bool_constraint[lit] = true; - flat_vars.push_back(lit); - flat_terms.push_back({c, 1}); + var_to_columns_builder.Add(lit, {c, 1}); } continue; } @@ -10756,8 +10905,7 @@ void CpModelPresolver::DetectDuplicateColumns() { for (int i = 0; i < num_terms; ++i) { const int var = ct.linear().vars(i); const int64_t coeff = ct.linear().coeffs(i); - flat_vars.push_back(var); - flat_terms.push_back({c, coeff}); + var_to_columns_builder.Add(var, {c, coeff}); } continue; } @@ -10769,19 +10917,19 @@ void CpModelPresolver::DetectDuplicateColumns() { // It might not be easy to presolve, but we can at least have a single // variable = sum of var appearing only in objective. And we can transfer the // min cost. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { context_->WriteObjectiveToProto(); - const int num_terms = context_->working_model->objective().vars().size(); + const int num_terms = context_->WorkingModel().objective().vars().size(); for (int i = 0; i < num_terms; ++i) { - const int var = context_->working_model->objective().vars(i); - const int64_t coeff = context_->working_model->objective().coeffs(i); - flat_vars.push_back(var); - flat_terms.push_back({kObjectiveConstraint, coeff}); + const int var = context_->WorkingModel().objective().vars(i); + const int64_t coeff = context_->WorkingModel().objective().coeffs(i); + var_to_columns_builder.Add(var, {kObjectiveConstraint, coeff}); } } // Now construct the graph. - var_to_columns.ResetFromFlatMapping(flat_vars, flat_terms); + CompactVectorVector> var_to_columns; + var_to_columns.ResetFromBuilder(var_to_columns_builder); // Find duplicate columns using an hash map. // We only consider "full" columns. @@ -10792,8 +10940,7 @@ void CpModelPresolver::DetectDuplicateColumns() { /*capacity=*/num_vars, ColumnHashForDuplicateDetection(&var_to_columns), ColumnEqForDuplicateDetection(&var_to_columns)); - std::vector flat_duplicates; - std::vector flat_representatives; + CompactVectorVectorBuilder rep_to_dups_builder; for (int var = 0; var < var_to_columns.size(); ++var) { const int size_seen = var_to_columns[var].size(); if (size_seen == 0) continue; @@ -10811,15 +10958,13 @@ void CpModelPresolver::DetectDuplicateColumns() { const auto [it, inserted] = duplicates.insert({var, var}); if (!inserted) { - flat_duplicates.push_back(var); - flat_representatives.push_back(it->second); + rep_to_dups_builder.Add(it->second, var); } } // Process duplicates. int num_equivalent_classes = 0; - CompactVectorVector rep_to_dups; - rep_to_dups.ResetFromFlatMapping(flat_representatives, flat_duplicates); + const CompactVectorVector rep_to_dups(rep_to_dups_builder); std::vector> definition; std::vector var_to_remove; std::vector var_to_rep(num_vars, -1); @@ -10885,8 +11030,7 @@ void CpModelPresolver::DetectDuplicateColumns() { absl::flat_hash_set seen; std::vector> new_terms; for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* mutable_ct = - context_->working_model->mutable_constraints(c); + ConstraintProto* mutable_ct = context_->MutableConstraint(c); seen.clear(); new_terms.clear(); @@ -10983,7 +11127,7 @@ void CpModelPresolver::DetectDuplicateConstraints() { PresolveTimer timer(__FUNCTION__, logger_, time_limit_); // We need the objective written for this. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { if (!context_->CanonicalizeObjective()) return; context_->WriteObjectiveToProto(); } @@ -10997,20 +11141,18 @@ void CpModelPresolver::DetectDuplicateConstraints() { // TODO(user): We might want to do that earlier so that our count of variable // usage is not biased by duplicate constraints. const std::vector> duplicates = - FindDuplicateConstraints(*context_->working_model); + FindDuplicateConstraints(context_->WorkingModel()); timer.AddCounter("duplicates", duplicates.size()); for (const auto& [dup, rep] : duplicates) { // Note that it is important to look at the type of the representative in // case the constraint became empty. DCHECK_LT(kObjectiveConstraint, 0); - const int type = - rep == kObjectiveConstraint - ? kObjectiveConstraint - : context_->working_model->constraints(rep).constraint_case(); + const int type = rep == kObjectiveConstraint + ? kObjectiveConstraint + : context_->Constraint(rep).constraint_case(); if (type == ConstraintProto::kInterval) { - interval_mapping.resize(context_->working_model->constraints().size(), - -1); + interval_mapping.resize(context_->NumConstraints(), -1); CHECK_EQ(interval_mapping[rep], -1); interval_mapping[dup] = rep; } @@ -11018,17 +11160,15 @@ void CpModelPresolver::DetectDuplicateConstraints() { // For linear constraint, we merge their rhs since it was ignored in the // FindDuplicateConstraints() call. if (type == ConstraintProto::kLinear) { - const Domain rep_domain = ReadDomainFromProto( - context_->working_model->constraints(rep).linear()); - const Domain d = ReadDomainFromProto( - context_->working_model->constraints(dup).linear()); + const Domain rep_domain = + ReadDomainFromProto(context_->Constraint(rep).linear()); + const Domain d = ReadDomainFromProto(context_->Constraint(dup).linear()); if (rep_domain != d) { context_->UpdateRuleStats("duplicate: merged rhs of linear constraint"); const Domain rhs = rep_domain.IntersectionWith(d); if (rhs.IsEmpty()) { - if (!MarkConstraintAsFalse( - context_->working_model->mutable_constraints(rep), - "duplicate: false after merging")) { + if (!MarkConstraintAsFalse(context_->MutableConstraint(rep), + "duplicate: false after merging")) { return; } @@ -11038,43 +11178,27 @@ void CpModelPresolver::DetectDuplicateConstraints() { context_->UpdateConstraintVariableUsage(rep); continue; } - FillDomainInProto(rhs, context_->working_model->mutable_constraints(rep) - ->mutable_linear()); + FillDomainInProto(rhs, + context_->MutableConstraint(rep)->mutable_linear()); } } if (type == kObjectiveConstraint) { context_->UpdateRuleStats( "duplicate: linear constraint parallel to objective"); - const Domain objective_domain = - ReadDomainFromProto(context_->working_model->objective()); - const Domain d = ReadDomainFromProto( - context_->working_model->constraints(dup).linear()); - if (objective_domain != d) { - context_->UpdateRuleStats("duplicate: updated objective domain"); - const Domain new_domain = objective_domain.IntersectionWith(d); - if (new_domain.IsEmpty()) { - return (void)context_->NotifyThatModelIsUnsat( - "Constraint parallel to the objective makes the objective domain " - "empty"); - } - FillDomainInProto(new_domain, - context_->working_model->mutable_objective()); - - // TODO(user): this write/read is a bit unclean, but needed. - context_->ReadObjectiveFromProto(); - } + const Domain d = ReadDomainFromProto(context_->Constraint(dup).linear()); + if (!context_->RestrictObjectiveDomain(d)) return; } // Remove the duplicate constraint. - context_->working_model->mutable_constraints(dup)->Clear(); + context_->ClearConstraint(dup); context_->UpdateConstraintVariableUsage(dup); context_->UpdateRuleStats("duplicate: removed constraint"); } if (!interval_mapping.empty()) { context_->UpdateRuleStats("duplicate: remapped duplicate intervals"); - const int num_constraints = context_->working_model->constraints().size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { bool changed = false; ApplyToAllIntervalIndices( @@ -11085,7 +11209,7 @@ void CpModelPresolver::DetectDuplicateConstraints() { *ref = new_ref; } }, - context_->working_model->mutable_constraints(c)); + context_->MutableConstraint(c)); if (changed) context_->UpdateConstraintVariableUsage(c); } } @@ -11099,7 +11223,7 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( PresolveTimer timer(__FUNCTION__, logger_, time_limit_); // We need the objective written for this. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { if (!context_->CanonicalizeObjective()) return; context_->WriteObjectiveToProto(); } @@ -11111,12 +11235,12 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // cte and expr + Y = other_cte, we can see that X is in affine relation with // Y. const std::vector> duplicates_without_enforcement = - FindDuplicateConstraints(*context_->working_model, true); + FindDuplicateConstraints(context_->WorkingModel(), true); timer.AddCounter("without_enforcements", duplicates_without_enforcement.size()); for (const auto& [dup, rep] : duplicates_without_enforcement) { - auto* dup_ct = context_->working_model->mutable_constraints(dup); - auto* rep_ct = context_->working_model->mutable_constraints(rep); + auto* dup_ct = context_->MutableConstraint(dup); + auto* rep_ct = context_->MutableConstraint(rep); if (dup_ct->constraint_case() == ConstraintProto::kInterval) { context_->UpdateRuleStats( @@ -11251,8 +11375,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // B, then constraint A is redundant and we can remove it. const int c_a = i == 0 ? dup : rep; const int c_b = i == 0 ? rep : dup; - const auto& ct_a = context_->working_model->constraints(c_a); - const auto& ct_b = context_->working_model->constraints(c_b); + const auto& ct_a = context_->Constraint(c_a); + const auto& ct_b = context_->Constraint(c_b); enforcement_vars.clear(); implications_used.clear(); @@ -11359,7 +11483,6 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( const int proto_lit_b = mapping->GetProtoLiteralFromLiteral(b); context_->AddImplication(proto_lit_a, proto_lit_b); } - context_->UpdateNewConstraintsVariableUsage(); break; } } @@ -11393,10 +11516,9 @@ void CpModelPresolver::DetectDifferentVariables() { // so x - y > exclusion.Max() --> x > y + exclusion.Max(); // OR x - y < exclusion.Min() --> y > x - exclusion.Min(); different_vars.push_back(key); - offsets[key] = {exclusion.Min() == std::numeric_limits::min() - ? std::numeric_limits::max() - : CapAdd(-exclusion.Min(), 1), - CapAdd(exclusion.Max(), 1)}; + offsets[key] = { + exclusion.Min() == kint64min ? kint64max : CapAdd(-exclusion.Min(), 1), + CapAdd(exclusion.Max(), 1)}; }; // Try to find identical linear constraint with incompatible domains. @@ -11420,9 +11542,9 @@ void CpModelPresolver::DetectDifferentVariables() { bool has_all_diff = false; bool has_no_overlap = false; std::vector> hashes; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() == ConstraintProto::kAllDiff) { has_all_diff = true; continue; @@ -11468,12 +11590,10 @@ void CpModelPresolver::DetectDifferentVariables() { if (range.size() > 10) continue; for (int i = 0; i < range.size(); ++i) { - const ConstraintProto& ct1 = - context_->working_model->constraints(range[i].second); + const ConstraintProto& ct1 = context_->Constraint(range[i].second); const int num_terms = ct1.linear().vars().size(); for (int j = i + 1; j < range.size(); ++j) { - const ConstraintProto& ct2 = - context_->working_model->constraints(range[j].second); + const ConstraintProto& ct2 = context_->Constraint(range[j].second); if (ct2.linear().vars().size() != num_terms) continue; if (!ReadDomainFromProto(ct1.linear()) .IntersectionWith(ReadDomainFromProto(ct2.linear())) @@ -11566,7 +11686,7 @@ void CpModelPresolver::DetectDifferentVariables() { absl::flat_hash_set used_var; Model local_model; - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); local_model.GetOrCreate()->Resize(num_variables); auto* graph = local_model.GetOrCreate(); graph->Resize(num_variables); @@ -11593,7 +11713,7 @@ void CpModelPresolver::DetectDifferentVariables() { } } CHECK(graph->DetectEquivalences()); - graph->TransformIntoMaxCliques(&cliques, 1e8); + CHECK(graph->TransformIntoMaxCliques(&cliques, 1e8)); int num_cliques = 0; int64_t cumulative_size = 0; @@ -11609,8 +11729,7 @@ void CpModelPresolver::DetectDifferentVariables() { // scan of the model we just did above, since we had one linear constraint // per entry. const int num_terms = clique.size(); - std::vector sizes(num_terms, - std::numeric_limits::max()); + std::vector sizes(num_terms, kint64max); for (int i = 0; i < num_terms; ++i) { const int v1 = clique[i].Variable().value(); for (int j = i + 1; j < num_terms; ++j) { @@ -11631,7 +11750,7 @@ void CpModelPresolver::DetectDifferentVariables() { // any other. We should probably handle this case better, but for now we // abort. issue = CapAdd(issue, sizes[i]); - if (issue == std::numeric_limits::max()) { + if (issue == kint64max) { context_->UpdateRuleStats("TODO no_overlap: with task always last"); num_greater_than_one = 0; break; @@ -11649,9 +11768,8 @@ void CpModelPresolver::DetectDifferentVariables() { std::vector intervals; for (int i = 0; i < num_terms; ++i) { - intervals.push_back(context_->working_model->constraints().size()); - auto* new_interval = - context_->working_model->add_constraints()->mutable_interval(); + intervals.push_back(context_->NumConstraints()); + auto* new_interval = context_->AddConstraint()->mutable_interval(); new_interval->mutable_start()->set_offset(0); new_interval->mutable_start()->add_coeffs(1); new_interval->mutable_start()->add_vars(clique[i].Variable().value()); @@ -11662,15 +11780,13 @@ void CpModelPresolver::DetectDifferentVariables() { new_interval->mutable_end()->add_coeffs(1); new_interval->mutable_end()->add_vars(clique[i].Variable().value()); } - auto* new_ct = - context_->working_model->add_constraints()->mutable_no_overlap(); + auto* new_ct = context_->AddConstraint()->mutable_no_overlap(); for (const int interval : intervals) { new_ct->add_intervals(interval); } } else { context_->UpdateRuleStats("all_diff: inferred from x != y constraints"); - auto* new_ct = - context_->working_model->add_constraints()->mutable_all_diff(); + auto* new_ct = context_->AddConstraint()->mutable_all_diff(); for (const Literal l : clique) { auto* expr = new_ct->add_exprs(); expr->add_vars(l.Variable().value()); @@ -11683,8 +11799,6 @@ void CpModelPresolver::DetectDifferentVariables() { timer.AddCounter("cliques", num_cliques); timer.AddCounter("size", cumulative_size); } - - context_->UpdateNewConstraintsVariableUsage(); } namespace { @@ -11730,7 +11844,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // reference there. class Storage { public: - explicit Storage(CpModelProto* proto) : proto_(*proto) {} + explicit Storage(CpModelProto const* proto) : proto_(*proto) {} int size() const { return static_cast(proto_.constraints().size()); } absl::Span operator[](int c) const { return absl::MakeSpan(proto_.constraints(c).linear().vars()); @@ -11739,7 +11853,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { private: const CpModelProto& proto_; }; - Storage storage(context_->working_model); + Storage storage(&context_->WorkingModel()); InclusionDetector detector(storage, time_limit_); detector.SetWorkLimit(context_->params().presolve_inclusion_work_limit()); @@ -11751,9 +11865,9 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // TODO(user): maybe we should store this instead of recomputing it. absl::flat_hash_map cached_expr_domain; - const int num_constraints = context_->working_model->constraints().size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; // We only look at long enforced constraint to avoid all the linear of size @@ -11782,13 +11896,12 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { ++num_inclusions; // Store the coeff of the subset linear constraint in a map. - const ConstraintProto& subset_ct = - context_->working_model->constraints(subset_c); + const ConstraintProto& subset_ct = context_->Constraint(subset_c); const LinearConstraintProto& subset_lin = subset_ct.linear(); coeff_map.clear(); detector.IncreaseWorkDone(subset_lin.vars().size()); for (int i = 0; i < subset_lin.vars().size(); ++i) { - coeff_map[subset_lin.vars(i)] = subset_lin.coeffs(i); + coeff_map[subset_lin.vars(i)] += subset_lin.coeffs(i); } // We have a perfect match if 'factor_a * subset == factor_b * superset' on @@ -11798,15 +11911,14 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // Find interesting factor of the subset that cancels terms of the superset. int64_t factor = 0; - int64_t min_pos_factor = std::numeric_limits::max(); - int64_t max_neg_factor = std::numeric_limits::min(); + int64_t min_pos_factor = kint64max; + int64_t max_neg_factor = kint64min; // Lets compute the implied domain of the linear expression // "superset - subset". Note that we actually do not need exact inclusion // for this algorithm to work, but it is an heuristic to not try it with // all pair of constraints. - const ConstraintProto& superset_ct = - context_->working_model->constraints(superset_c); + const ConstraintProto& superset_ct = context_->Constraint(superset_c); const LinearConstraintProto& superset_lin = superset_ct.linear(); int64_t diff_min_activity = 0; int64_t diff_max_activity = 0; @@ -11853,19 +11965,20 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // Case 1: superset is redundant. // We process this one first as it let us remove the longest constraint. - // - // Important: because of how we computed the inclusion, the diff_domain is - // only valid if none of the enforcement appear in the subset. - // - // TODO(user): Compute the correct infered domain in this case. - if (subset_ct.enforcement_literal().empty()) { + { const Domain implied_superset_domain = subset_rhs.AdditionWith(diff_domain) .IntersectionWith(cached_expr_domain[superset_c]); - if (implied_superset_domain.IsIncludedIn(superset_rhs)) { - context_->UpdateRuleStats( - "linear inclusion: redundant containing constraint"); - context_->working_model->mutable_constraints(superset_c)->Clear(); + if (implied_superset_domain.IsIncludedIn(superset_rhs) && + std::includes(superset_ct.enforcement_literal().begin(), + superset_ct.enforcement_literal().end(), + subset_ct.enforcement_literal().begin(), + subset_ct.enforcement_literal().end())) { + context_->UpdateRuleStats(absl::StrCat( + "linear inclusion: redundant containing constraint", + subset_ct.enforcement_literal().empty() ? "" + : " (with enforcement)")); + context_->ClearConstraint(superset_c); constraint_indices_to_clean.push_back(superset_c); detector.StopProcessingCurrentSuperset(); return; @@ -11873,14 +11986,20 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { } // Case 2: subset is redundant. - if (superset_ct.enforcement_literal().empty()) { + { const Domain implied_subset_domain = superset_rhs.AdditionWith(diff_domain.Negation()) .IntersectionWith(cached_expr_domain[subset_c]); - if (implied_subset_domain.IsIncludedIn(subset_rhs)) { - context_->UpdateRuleStats( - "linear inclusion: redundant included constraint"); - context_->working_model->mutable_constraints(subset_c)->Clear(); + if (implied_subset_domain.IsIncludedIn(subset_rhs) && + std::includes(subset_ct.enforcement_literal().begin(), + subset_ct.enforcement_literal().end(), + superset_ct.enforcement_literal().begin(), + superset_ct.enforcement_literal().end())) { + context_->UpdateRuleStats(absl::StrCat( + "linear inclusion: redundant included constraint", + superset_ct.enforcement_literal().empty() ? "" + : " (with enforcement)")); + context_->ClearConstraint(subset_c); constraint_indices_to_clean.push_back(subset_c); detector.StopProcessingCurrentSubset(); return; @@ -11902,8 +12021,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // TODO(user): can we compute the best factor that make this as tight as // possible instead? that looks doable. bool is_tigher = true; - if (min_pos_factor != std::numeric_limits::max() && - max_neg_factor != std::numeric_limits::min()) { + if (min_pos_factor != kint64max && max_neg_factor != kint64min) { int64_t min_before = 0; int64_t max_before = 0; int64_t min_after = CapProd(best_factor, subset_rhs.FixedValue()); @@ -11925,8 +12043,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { if (is_tigher) { context_->UpdateRuleStats("linear inclusion: sparsify superset"); Substitute(-best_factor, coeff_map, subset_rhs, superset_rhs, - context_->working_model->mutable_constraints(superset_c) - ->mutable_linear()); + context_->MutableConstraint(superset_c)->mutable_linear()); constraint_indices_to_clean.push_back(superset_c); detector.StopProcessingCurrentSuperset(); return; @@ -11942,8 +12059,9 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // Propagate domain on the superset - subset variables. // TODO(user): We can probably still do that if the inclusion is not // perfect. - temp_ct_.Clear(); + DCHECK(temp_ct_.enforcement_literal().empty()); auto* mutable_linear = temp_ct_.mutable_linear(); + mutable_linear->Clear(); for (int i = 0; i < superset_lin.vars().size(); ++i) { const int var = superset_lin.vars(i); const int64_t coeff = superset_lin.coeffs(i); @@ -11964,7 +12082,7 @@ void CpModelPresolver::DetectDominatedLinearConstraints() { // and we have an equality, the subset is redundant! context_->UpdateRuleStats( "linear inclusion: subset + singleton is equality"); - context_->working_model->mutable_constraints(subset_c)->Clear(); + context_->ClearConstraint(subset_c); constraint_indices_to_clean.push_back(subset_c); detector.StopProcessingCurrentSubset(); return; @@ -12002,7 +12120,7 @@ bool CpModelPresolver::RemoveCommonPart( // new variable. int definiting_equation = -1; for (const auto [c, multiple] : block) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (std::abs(multiple) != 1) continue; if (!IsLinearEqualityConstraint(ct)) continue; if (ct.linear().vars().size() != common_var_coeff_map.size() + 1) continue; @@ -12080,8 +12198,7 @@ bool CpModelPresolver::RemoveCommonPart( for (const auto [c, multiple] : block) { if (c == definiting_equation) continue; - auto* mutable_linear = - context_->working_model->mutable_constraints(c)->mutable_linear(); + auto* mutable_linear = context_->MutableConstraint(c)->mutable_linear(); const int num_terms = mutable_linear->vars().size(); int new_size = 0; bool new_var_already_seen = false; @@ -12163,7 +12280,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( // We store for each var if the literal was positive or not. absl::flat_hash_map var_in_amo; - for (int x = 0; x < context_->working_model->variables().size(); ++x) { + for (int x = 0; x < context_->NumVariables(); ++x) { // We pick a variable x that appear in some AMO. if (helper->NumAmoForVariable(x) == 0) continue; if (time_limit_->LimitReached()) break; @@ -12173,7 +12290,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( timer.TrackSimpleLoop(context_->VarToConstraints(x).size()); for (const int c : context_->VarToConstraints(x)) { if (c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() == ConstraintProto::kAtMostOne) { amo_cts.push_back(c); } else if (ct.constraint_case() == ConstraintProto::kExactlyOne) { @@ -12198,10 +12315,9 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( // For determinism. std::sort(amo_cts.begin(), amo_cts.end()); const int random_c = - absl::Uniform(*context_->random(), 0, amo_cts.size()); + absl::Uniform(context_->random(), 0, amo_cts.size()); base_ct_index = amo_cts[random_c]; - const ConstraintProto& ct = - context_->working_model->constraints(base_ct_index); + const ConstraintProto& ct = context_->Constraint(base_ct_index); const auto& literals = ct.constraint_case() == ConstraintProto::kAtMostOne ? ct.at_most_one().literals() : ct.exactly_one().literals(); @@ -12226,7 +12342,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( timer.TrackSimpleLoop(context_->VarToConstraints(x).size()); for (const int c : context_->VarToConstraints(x)) { if (c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; const int num_terms = ct.linear().vars().size(); if (num_terms < 2) continue; @@ -12276,9 +12392,9 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( // We will just greedily compute a big block with a random order. // TODO(user): We could sort by match with the full constraint instead. - std::shuffle(linear_cts.begin(), linear_cts.end(), *context_->random()); + std::shuffle(linear_cts.begin(), linear_cts.end(), context_->random()); for (const int c : linear_cts) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); const int num_terms = ct.linear().vars().size(); timer.TrackSimpleLoop(2 * num_terms); const int64_t x_coeff = x_multiplier * FindVarCoeff(x, ct); @@ -12341,8 +12457,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( amo_literals[new_size++] = lit; } if (new_size == amo_literals.size()) { - const ConstraintProto& ct = - context_->working_model->constraints(base_ct_index); + const ConstraintProto& ct = context_->Constraint(base_ct_index); if (ct.constraint_case() == ConstraintProto::kExactlyOne) { context_->UpdateRuleStats("TODO linear matrix: constant rectangle!"); } else { @@ -12350,8 +12465,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( "TODO linear matrix: reuse defining constraint"); } } else if (new_size + 1 == amo_literals.size()) { - const ConstraintProto& ct = - context_->working_model->constraints(base_ct_index); + const ConstraintProto& ct = context_->Constraint(base_ct_index); if (ct.constraint_case() == ConstraintProto::kExactlyOne) { context_->UpdateRuleStats("TODO linear matrix: reuse exo constraint"); } @@ -12361,20 +12475,17 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( // Create a new literal that is one iff one of the literal in AMO is one. const int new_var = context_->NewBoolVarWithClause(amo_literals); { - auto* new_exo = - context_->working_model->add_constraints()->mutable_exactly_one(); + auto* new_exo = context_->AddConstraint()->mutable_exactly_one(); new_exo->mutable_literals()->Reserve(amo_literals.size() + 1); for (const int lit : amo_literals) { new_exo->add_literals(lit); } new_exo->add_literals(NegatedRef(new_var)); - context_->UpdateNewConstraintsVariableUsage(); } // Filter the base amo/exo. { - ConstraintProto* ct = - context_->working_model->mutable_constraints(base_ct_index); + ConstraintProto* ct = context_->MutableConstraint(base_ct_index); auto* mutable_literals = ct->constraint_case() == ConstraintProto::kAtMostOne ? ct->mutable_at_most_one()->mutable_literals() @@ -12391,8 +12502,7 @@ void CpModelPresolver::FindBigAtMostOneAndLinearOverlap( // Use this Boolean in all the linear constraints. for (const int c : block_cts) { - auto* mutable_linear = - context_->working_model->mutable_constraints(c)->mutable_linear(); + auto* mutable_linear = context_->MutableConstraint(c)->mutable_linear(); // The removed expression will be (offset + coeff_x * new_bool). int64_t offset = 0; @@ -12455,7 +12565,7 @@ void CpModelPresolver::FindBigVerticalLinearOverlap( int64_t num_blocks = 0; int64_t nz_reduction = 0; absl::flat_hash_map coeff_map; - for (int x = 0; x < context_->working_model->variables().size(); ++x) { + for (int x = 0; x < context_->NumVariables(); ++x) { if (timer.WorkLimitIsReached()) break; bool in_enforcement = false; @@ -12463,7 +12573,7 @@ void CpModelPresolver::FindBigVerticalLinearOverlap( timer.TrackSimpleLoop(context_->VarToConstraints(x).size()); for (const int c : context_->VarToConstraints(x)) { if (c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; const int num_terms = ct.linear().vars().size(); @@ -12501,7 +12611,7 @@ void CpModelPresolver::FindBigVerticalLinearOverlap( // For determinism. std::sort(linear_cts.begin(), linear_cts.end()); - std::shuffle(linear_cts.begin(), linear_cts.end(), *context_->random()); + std::shuffle(linear_cts.begin(), linear_cts.end(), context_->random()); // Now it is almost the same algo as for FindBigHorizontalLinearOverlap(). // We greedely compute a "common" rectangle using the first constraint @@ -12513,7 +12623,7 @@ void CpModelPresolver::FindBigVerticalLinearOverlap( std::vector> block; std::vector> common_part; for (const int c : linear_cts) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); const int num_terms = ct.linear().vars().size(); timer.TrackSimpleLoop(num_terms); @@ -12566,6 +12676,21 @@ void CpModelPresolver::FindBigVerticalLinearOverlap( ComputeNonZeroReduction(block.size(), coeff_map.size()); if (saved_nz < 30) continue; + // When we have a wide range of coefficient, introducing a new variable can + // hurt the linear relaxation cuts, because we lose a lot of information + // about the integrality while reasoning on the sum. So we do that more + // defensively. + // + // This avoid degrading the perf a lot on the bppc miplib problems. + int64_t min_magnitude = kint64max; + int64_t max_magnitude = 0; + for (const auto [unused, coeff] : coeff_map) { + const int64_t magnitude = std::abs(coeff); + min_magnitude = std::min(min_magnitude, magnitude); + max_magnitude = std::max(max_magnitude, magnitude); + } + if (min_magnitude != max_magnitude && saved_nz < 1'000) continue; + // Fix multiples, currently this contain the coeff of x for each constraint. const int64_t base_x = coeff_map.at(x); for (auto& [c, multipier] : block) { @@ -12599,10 +12724,10 @@ void CpModelPresolver::FindBigHorizontalLinearOverlap( if (context_->params().presolve_inclusion_work_limit() == 0) return; PresolveTimer timer(__FUNCTION__, logger_, time_limit_); - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); std::vector> to_sort; for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; const int size = ct.linear().vars().size(); if (size < 5) continue; @@ -12618,8 +12743,7 @@ void CpModelPresolver::FindBigHorizontalLinearOverlap( // On large problem, using and hash_map can be slow, so we use the vector // version and for now fill the map only when doing the change. std::vector var_to_coeff_non_zeros; - std::vector var_to_coeff(context_->working_model->variables_size(), - 0); + std::vector var_to_coeff(context_->NumVariables(), 0); int64_t num_blocks = 0; int64_t nz_reduction = 0; @@ -12633,7 +12757,7 @@ void CpModelPresolver::FindBigHorizontalLinearOverlap( } var_to_coeff_non_zeros.clear(); { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); const int num_terms = ct.linear().vars().size(); timer.TrackSimpleLoop(num_terms); for (int k = 0; k < num_terms; ++k) { @@ -12657,7 +12781,7 @@ void CpModelPresolver::FindBigHorizontalLinearOverlap( if (i == j) continue; const int other_c = sorted_linear[j]; if (other_c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(other_c); + const ConstraintProto& ct = context_->Constraint(other_c); // No need to continue if linear is not large enough. const int num_terms = ct.linear().vars().size(); @@ -12714,8 +12838,7 @@ void CpModelPresolver::FindBigHorizontalLinearOverlap( int new_match_size = 0; const int other_c = sorted_linear[index]; - const ConstraintProto& ct = - context_->working_model->constraints(other_c); + const ConstraintProto& ct = context_->Constraint(other_c); const int num_terms = ct.linear().vars().size(); for (int k = 0; k < num_terms; ++k) { if (var_to_coeff[ct.linear().vars(k)] == ct.linear().coeffs(k)) { @@ -12767,9 +12890,9 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { // Only keep non-enforced linear equality of size > 2. Sort by size. std::vector> to_sort; - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (!IsLinearEqualityConstraint(ct)) continue; if (ct.linear().vars().size() <= 2) continue; @@ -12786,7 +12909,7 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { // This is similar to what is used by the inclusion detector. std::vector var_to_clear; std::vector>> var_to_ct_coeffs_; - const int num_variables = context_->working_model->variables_size(); + const int num_variables = context_->NumVariables(); var_to_ct_coeffs_.resize(num_variables); int end; @@ -12807,10 +12930,9 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { var_to_clear.clear(); for (int i = start; i < end; ++i) { const int c = to_sort[i].second; - const LinearConstraintProto& lin = - context_->working_model->constraints(c).linear(); + const LinearConstraintProto& lin = context_->Constraint(c).linear(); const int index = - absl::Uniform(*context_->random(), 0, lin.vars().size()); + absl::Uniform(context_->random(), 0, lin.vars().size()); const int var = lin.vars(index); if (var_to_ct_coeffs_[var].empty()) var_to_clear.push_back(var); var_to_ct_coeffs_[var].push_back({c, lin.coeffs(index)}); @@ -12825,8 +12947,7 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { for (int i1 = start; i1 < end; ++i1) { if (timer.WorkLimitIsReached()) break; const int c1 = to_sort[i1].second; - const LinearConstraintProto& lin1 = - context_->working_model->constraints(c1).linear(); + const LinearConstraintProto& lin1 = context_->Constraint(c1).linear(); bool skip = false; for (int i = 0; !skip && i < lin1.vars().size(); ++i) { for (const auto [c2, coeff2] : var_to_ct_coeffs_[lin1.vars(i)]) { @@ -12837,10 +12958,9 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { if (timer.WorkLimitIsReached()) break; // Skip if we processed this earlier and deleted it. - const ConstraintProto& ct2 = context_->working_model->constraints(c2); + const ConstraintProto& ct2 = context_->Constraint(c2); if (ct2.constraint_case() != ConstraintProto::kLinear) continue; - const LinearConstraintProto& lin2 = - context_->working_model->constraints(c2).linear(); + const LinearConstraintProto& lin2 = context_->Constraint(c2).linear(); if (lin2.vars().size() != length) continue; // TODO(user): In practice LinearsDifferAtOneTerm() will abort @@ -12853,9 +12973,9 @@ void CpModelPresolver::FindAlmostIdenticalLinearConstraints() { // The two equalities only differ at one term ! // do c1 -= c2 and presolve c1 right away. // We should detect new affine relation and remove it. - auto* to_modify = context_->working_model->mutable_constraints(c1); - if (!AddLinearConstraintMultiple( - -1, context_->working_model->constraints(c2), to_modify)) { + auto* to_modify = context_->MutableConstraint(c1); + if (!AddLinearConstraintMultiple(-1, context_->Constraint(c2), + to_modify)) { continue; } @@ -12902,9 +13022,9 @@ void CpModelPresolver::ExtractEncodingFromLinear() { // TODO(user): Ideally we want to process exactly_one first in case a // linear constraint is both included in an at_most_one and an exactly_one. std::vector vars; - const int num_constraints = context_->working_model->constraints().size(); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); switch (ct.constraint_case()) { case ConstraintProto::kAtMostOne: { vars.clear(); @@ -12974,8 +13094,7 @@ void CpModelPresolver::ExtractEncodingFromLinear() { detector.DetectInclusions([&](int subset, int superset) { const int subset_c = relevant_constraints[subset]; const int superset_c = relevant_constraints[superset]; - const ConstraintProto& superset_ct = - context_->working_model->constraints(superset_c); + const ConstraintProto& superset_ct = context_->Constraint(superset_c); if (superset_ct.constraint_case() == ConstraintProto::kAtMostOne) { ++num_at_most_one_encodings; } else { @@ -13014,7 +13133,6 @@ void CpModelPresolver::ExtractEncodingFromLinear() { // encoding that will later be re-created during expansion. void CpModelPresolver::LookAtVariableWithDegreeTwo(int var) { CHECK(RefIsPositive(var)); - CHECK(context_->ConstraintVariableGraphIsUpToDate()); if (context_->ModelIsUnsat()) return; if (context_->params().keep_all_feasible_solutions_in_presolve()) return; if (context_->IsFixed(var)) return; @@ -13032,14 +13150,12 @@ void CpModelPresolver::LookAtVariableWithDegreeTwo(int var) { int ct_var = -1; Domain union_of_domain; int num_positive = 0; - std::vector constraint_indices_to_remove; for (const int c : context_->VarToConstraints(var)) { if (c < 0) { abort = true; break; } - constraint_indices_to_remove.push_back(c); - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.enforcement_literal().size() != 1 || PositiveRef(ct.enforcement_literal(0)) != var || ct.constraint_case() != ConstraintProto::kLinear || @@ -13064,15 +13180,18 @@ void CpModelPresolver::LookAtVariableWithDegreeTwo(int var) { if (!context_->IntersectDomainWith(ct_var, union_of_domain)) return; context_->UpdateRuleStats("variables: removable enforcement literal"); - absl::c_sort(constraint_indices_to_remove); // For determinism + + // For determinism. + const auto& set = context_->VarToConstraints(var); + std::vector constraint_indices_to_remove(set.begin(), set.end()); + absl::c_sort(constraint_indices_to_remove); // Note(user): Only one constraint should be enough given how the postsolve // work. However that will not work for the case where we postsolve by solving // the mapping model (debug_postsolve_with_full_solver:true). for (const int c : constraint_indices_to_remove) { - context_->NewMappingConstraint(context_->working_model->constraints(c), - __FILE__, __LINE__); - context_->working_model->mutable_constraints(c)->Clear(); + context_->NewMappingConstraint(context_->Constraint(c), __FILE__, __LINE__); + context_->ClearConstraint(c); context_->UpdateConstraintVariableUsage(c); } context_->MarkVariableAsRemoved(var); @@ -13092,7 +13211,6 @@ absl::Span AtMostOneOrExactlyOneLiterals(const ConstraintProto& ct) { void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { DCHECK(RefIsPositive(var)); - DCHECK(context_->ConstraintVariableGraphIsUpToDate()); if (context_->ModelIsUnsat()) return; if (context_->IsFixed(var)) return; if (context_->VariableWasRemoved(var)) return; @@ -13114,7 +13232,7 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { int c2 = -1; for (const int c : context_->VarToConstraints(var)) { if (c < 0) continue; - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() != ConstraintProto::kAtMostOne && ct.constraint_case() != ConstraintProto::kExactlyOne) { return; @@ -13141,8 +13259,8 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { // individually and just be removed. So we abort here as the code below // is incorrect if new_ct is an empty constraint. context_->tmp_literals.clear(); - int c1_ref = std::numeric_limits::min(); - const ConstraintProto& ct1 = context_->working_model->constraints(c1); + int c1_ref = kint32min; + const ConstraintProto& ct1 = context_->Constraint(c1); if (AtMostOneOrExactlyOneLiterals(ct1).size() <= 1) return; for (const int lit : AtMostOneOrExactlyOneLiterals(ct1)) { if (PositiveRef(lit) == var) { @@ -13151,8 +13269,8 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { context_->tmp_literals.push_back(lit); } } - int c2_ref = std::numeric_limits::min(); - const ConstraintProto& ct2 = context_->working_model->constraints(c2); + int c2_ref = kint32min; + const ConstraintProto& ct2 = context_->Constraint(c2); if (AtMostOneOrExactlyOneLiterals(ct2).size() <= 1) return; for (const int lit : AtMostOneOrExactlyOneLiterals(ct2)) { if (PositiveRef(lit) == var) { @@ -13161,8 +13279,8 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { context_->tmp_literals.push_back(lit); } } - DCHECK_NE(c1_ref, std::numeric_limits::min()); - DCHECK_NE(c2_ref, std::numeric_limits::min()); + DCHECK_NE(c1_ref, kint32min); + DCHECK_NE(c2_ref, kint32min); if (c1_ref != NegatedRef(c2_ref)) return; // If the cost is non-zero, we can use an exactly one to make it zero. @@ -13201,8 +13319,8 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { ->Assign(literals.begin(), literals.end()); // We can now replace the two constraint by a single one, and delete var! - const int new_ct_index = context_->working_model->constraints().size(); - ConstraintProto* new_ct = context_->working_model->add_constraints(); + const int new_ct_index = context_->NumConstraints(); + ConstraintProto* new_ct = context_->AddConstraint(); if (ct1.constraint_case() == ConstraintProto::kExactlyOne && ct2.constraint_case() == ConstraintProto::kExactlyOne) { for (const int lit : context_->tmp_literals) { @@ -13216,10 +13334,9 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { } } - context_->UpdateNewConstraintsVariableUsage(); - context_->working_model->mutable_constraints(c1)->Clear(); + context_->ClearConstraint(c1); context_->UpdateConstraintVariableUsage(c1); - context_->working_model->mutable_constraints(c2)->Clear(); + context_->ClearConstraint(c2); context_->UpdateConstraintVariableUsage(c2); context_->UpdateRuleStats( @@ -13230,7 +13347,7 @@ void CpModelPresolver::ProcessVariableInTwoAtMostOrExactlyOne(int var) { // negation of other, we need to deal with that right away. For some reason // something is not robust to that it seems. Investigate & fix! DCHECK_NE(new_ct->constraint_case(), ConstraintProto::CONSTRAINT_NOT_SET); - if (PresolveAtMostOrExactlyOne(new_ct)) { + if (PresolveAtMostOrExactlyOne(new_ct, /*use_dual_reduction=*/true)) { context_->UpdateConstraintVariableUsage(new_ct_index); } } @@ -13254,7 +13371,7 @@ void CpModelPresolver::ProcessVariablesOnlyUsedInEncoding() { // For the corner case where a lot of variables are fixed/deleted and this is // called many time, it can make a big difference to skip variables that we // already know are not important. - const int num_variables = context_->working_model->variables().size(); + const int num_variables = context_->NumVariables(); for (int var = encoding_tmp_num_vars_; var < num_variables; ++var) { encoding_tmp_vars_.push_back(var); } @@ -13272,13 +13389,11 @@ void CpModelPresolver::ProcessVariablesOnlyUsedInEncoding() { context_->VariableIsOnlyUsedInEncodingAndMaybeInObjective(var); if (is_only_used_in_encoding) { // Process variables only used in encoding. - const int old_num_constraints = - context_->working_model->constraints_size(); + const int old_num_constraints = context_->NumConstraints(); TryToReplaceVariableByItsEncoding(var, context_, solution_crush_); // Presolve newly created constraints. - for (int c = old_num_constraints; - c < context_->working_model->constraints_size(); ++c) { + for (int c = old_num_constraints; c < context_->NumConstraints(); ++c) { if (PresolveOneConstraint(c)) { context_->UpdateConstraintVariableUsage(c); } @@ -13297,7 +13412,7 @@ void CpModelPresolver::ProcessVariablesOnlyUsedInEncoding() { false; for (const int c : context_->VarToConstraints(var)) { if (c >= 0) { - const ConstraintProto& ct = context_->working_model->constraints(c); + const ConstraintProto& ct = context_->Constraint(c); if (ct.constraint_case() == ConstraintProto::kLinear && ct.linear().vars().size() == 1 && ct.linear().vars(0) == var) { local_model.linear1_constraints.push_back(c); @@ -13335,7 +13450,6 @@ void CpModelPresolver::ProcessVariablesOnlyUsedInEncoding() { void CpModelPresolver::TryToSimplifyDomain(int var) { DCHECK(RefIsPositive(var)); - DCHECK(context_->ConstraintVariableGraphIsUpToDate()); if (context_->ModelIsUnsat()) return; if (context_->IsFixed(var)) return; if (context_->VariableWasRemoved(var)) return; @@ -13375,7 +13489,7 @@ void CpModelPresolver::TryToSimplifyDomain(int var) { // Adds all affine relations to our model for the variables that are still used. void CpModelPresolver::EncodeAllAffineRelations() { int64_t num_added = 0; - for (int var = 0; var < context_->working_model->variables_size(); ++var) { + for (int var = 0; var < context_->NumVariables(); ++var) { if (context_->IsFixed(var)) continue; const AffineRelation::Relation r = context_->GetAffineRelation(var); @@ -13390,7 +13504,7 @@ void CpModelPresolver::EncodeAllAffineRelations() { if (context_->IsFixed(var)) continue; ++num_added; - ConstraintProto* ct = context_->working_model->add_constraints(); + ConstraintProto* ct = context_->AddConstraint(); auto* arg = ct->mutable_linear(); arg->add_vars(var); arg->add_coeffs(1); @@ -13398,7 +13512,6 @@ void CpModelPresolver::EncodeAllAffineRelations() { arg->add_coeffs(-r.coeff); arg->add_domain(r.offset); arg->add_domain(r.offset); - context_->UpdateNewConstraintsVariableUsage(); } // Now that we encoded all remaining affine relation with constraints, we @@ -13439,7 +13552,7 @@ bool CpModelPresolver::ProcessChangedVariables(std::vector* in_queue, // TODO(user): Avoid reprocessing the constraints that changed the domain? if (context_->ModelIsUnsat()) return false; if (time_limit_->LimitReached()) return false; - in_queue->resize(context_->working_model->constraints_size(), false); + in_queue->resize(context_->NumConstraints(), false); const auto& vector_that_can_grow_during_iter = context_->modified_domains.PositionsSetAtLeastOnce(); for (int i = 0; i < vector_that_can_grow_during_iter.size(); ++i) { @@ -13454,11 +13567,10 @@ bool CpModelPresolver::ProcessChangedVariables(std::vector* in_queue, // TODO(user): Integrate these with TryToSimplifyDomain(). if (context_->ModelIsUnsat()) return false; - context_->UpdateNewConstraintsVariableUsage(); if (!context_->CanonicalizeOneObjectiveVariable(v)) return false; - in_queue->resize(context_->working_model->constraints_size(), false); + in_queue->resize(context_->NumConstraints(), false); const int size_before = queue->size(); for (const int c : context_->VarToConstraints(v)) { if (c >= 0 && !(*in_queue)[c]) { @@ -13488,7 +13600,7 @@ void CpModelPresolver::PresolveToFixPoint() { const int64_t max_num_operations = context_->params().debug_max_num_presolve_operations() > 0 ? context_->params().debug_max_num_presolve_operations() - : std::numeric_limits::max(); + : kint64max; // This is used for constraint having unique variables in them (i.e. not // appearing anywhere else) to not call the presolve more than once for this @@ -13496,11 +13608,10 @@ void CpModelPresolver::PresolveToFixPoint() { absl::flat_hash_set> var_constraint_pair_already_called; // The queue of "active" constraints, initialized to the non-empty ones. - std::vector in_queue(context_->working_model->constraints_size(), - false); + std::vector in_queue(context_->NumConstraints(), false); std::deque queue; for (int c = 0; c < in_queue.size(); ++c) { - if (context_->working_model->constraints(c).constraint_case() != + if (context_->Constraint(c).constraint_case() != ConstraintProto::CONSTRAINT_NOT_SET) { in_queue[c] = true; queue.push_back(c); @@ -13512,7 +13623,7 @@ void CpModelPresolver::PresolveToFixPoint() { // In September 2019, experiment on the flatzinc problems shows no changes in // the results. We should actually count the number of rules triggered. if (context_->params().permute_presolve_constraint_order()) { - std::shuffle(queue.begin(), queue.end(), *context_->random()); + std::shuffle(queue.begin(), queue.end(), context_->random()); } else { std::sort(queue.begin(), queue.end(), [this](int a, int b) { const int score_a = context_->ConstraintToVars(a).size(); @@ -13539,21 +13650,17 @@ void CpModelPresolver::PresolveToFixPoint() { in_queue[c] = false; queue.pop_front(); - const int old_num_constraint = - context_->working_model->constraints_size(); + const int old_num_constraint = context_->NumConstraints(); const bool changed = PresolveOneConstraint(c); if (context_->ModelIsUnsat()) { - SOLVER_LOG( - logger_, "Unsat after presolving constraint #", c, - " (warning, dump might be inconsistent): ", - ProtobufShortDebugString(context_->working_model->constraints(c))); + SOLVER_LOG(logger_, "Unsat after presolving constraint #", c, + " (warning, dump might be inconsistent): ", + ProtobufShortDebugString(context_->Constraint(c))); } // Add to the queue any newly created constraints. - const int new_num_constraints = - context_->working_model->constraints_size(); + const int new_num_constraints = context_->NumConstraints(); if (new_num_constraints > old_num_constraint) { - context_->UpdateNewConstraintsVariableUsage(); in_queue.resize(new_num_constraints, true); for (int c = old_num_constraint; c < new_num_constraints; ++c) { queue.push_back(c); @@ -13569,9 +13676,9 @@ void CpModelPresolver::PresolveToFixPoint() { if (context_->ModelIsUnsat()) return; - in_queue.resize(context_->working_model->constraints_size(), false); + in_queue.resize(context_->NumConstraints(), false); const auto& vector_that_can_grow_during_iter = - context_->var_with_reduced_small_degree.PositionsSetAtLeastOnce(); + context_->MutableVarWithReducedSmallDegree()->PositionsSetAtLeastOnce(); for (int i = 0; i < vector_that_can_grow_during_iter.size(); ++i) { const int v = vector_that_can_grow_during_iter[i]; if (context_->VariableIsNotUsedAnymore(v)) continue; @@ -13579,7 +13686,7 @@ void CpModelPresolver::PresolveToFixPoint() { // Remove the variable from the set to allow it to be pushed again. // This is necessary since a few affine logic needs to add the same // variable back to a second pass of processing. - context_->var_with_reduced_small_degree.Clear(v); + context_->MutableVarWithReducedSmallDegree()->Clear(v); // Make sure all affine relations are propagated. // This also remove the relation if the degree is now one. @@ -13592,7 +13699,7 @@ void CpModelPresolver::PresolveToFixPoint() { if (degree == 2 || degree == 3) { // Tricky: this function can add new constraint. ProcessVariableInTwoAtMostOrExactlyOne(v); - in_queue.resize(context_->working_model->constraints_size(), false); + in_queue.resize(context_->NumConstraints(), false); continue; } @@ -13617,7 +13724,7 @@ void CpModelPresolver::PresolveToFixPoint() { queue.push_back(c); } } - context_->var_with_reduced_small_degree.ResetAllToFalse(); + context_->MutableVarWithReducedSmallDegree()->ResetAllToFalse(); if (ProcessChangedVariables(&in_queue, &queue)) continue; @@ -13633,7 +13740,7 @@ void CpModelPresolver::PresolveToFixPoint() { // TODO(user): We can support assumptions but we need to not cut them out // of the feasible region. if (context_->params().keep_all_feasible_solutions_in_presolve()) break; - if (!context_->working_model->assumptions().empty()) break; + if (!context_->WorkingModel().assumptions().empty()) break; // Starts by the "faster" algo that exploit variables that can move freely // in one direction. Or variables that are just blocked by one constraint in @@ -13687,10 +13794,10 @@ void CpModelPresolver::PresolveToFixPoint() { // absent interval in the main propagation loop above. But we currently don't // maintain such list. if (!time_limit_->LimitReached()) { - const int num_constraints = context_->working_model->constraints_size(); + const int num_constraints = context_->NumConstraints(); TimeLimitCheckEveryNCalls bool_or_check_time_limit(100, time_limit_); for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); // We don't want to check the time limit at each "small" constraint as // there could be many. bool check_time_limit = false; @@ -13767,7 +13874,7 @@ void CpModelPresolver::MergeClauses() { // - Register already existing bool_and. // - score at_most_ones literals. // - Record bool_or. - const int num_variables = context_->working_model->variables_size(); + const int num_variables = context_->NumVariables(); std::vector bool_or_indices; std::vector literal_score(2 * num_variables, 0); const auto get_index = [](int ref) { @@ -13777,10 +13884,10 @@ void CpModelPresolver::MergeClauses() { int64_t num_collisions = 0; int64_t num_merges = 0; int64_t num_saved_literals = 0; - ClauseWithOneMissingHasher hasher(*context_->random()); - const int num_constraints = context_->working_model->constraints_size(); + ClauseWithOneMissingHasher hasher(context_->random()); + const int num_constraints = context_->NumConstraints(); for (int c = 0; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() == ConstraintProto::kBoolAnd) { if (ct->enforcement_literal().size() > 1) { // We need to sort the negated literals. @@ -13793,8 +13900,7 @@ void CpModelPresolver::MergeClauses() { to_clean.push_back(c); } else { // See if this is a true duplicate. If yes, merge rhs. - ConstraintProto* other_ct = - context_->working_model->mutable_constraints(it->second); + ConstraintProto* other_ct = context_->MutableConstraint(it->second); const absl::Span s1(ct->enforcement_literal()); const absl::Span s2(other_ct->enforcement_literal()); if (s1 == s2) { @@ -13838,7 +13944,7 @@ void CpModelPresolver::MergeClauses() { } for (const int c : bool_or_indices) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); bool merged = false; timer.TrackSimpleLoop(ct->bool_or().literals().size()); @@ -13849,7 +13955,7 @@ void CpModelPresolver::MergeClauses() { if (it != bool_and_map.end()) { ++num_collisions; const int base_c = it->second; - auto* and_ct = context_->working_model->mutable_constraints(base_c); + auto* and_ct = context_->MutableConstraint(base_c); if (ClauseIsEnforcementImpliesLiteral( ct->bool_or().literals(), and_ct->enforcement_literal(), ref)) { ++num_merges; @@ -13894,7 +14000,7 @@ void CpModelPresolver::MergeClauses() { // Retransform to bool_or bool_and with a single rhs. for (const int c : to_clean) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->bool_and().literals().size() > 1) { context_->UpdateConstraintVariableUsage(c); continue; @@ -13933,9 +14039,10 @@ CpModelPresolver::CpModelPresolver(PresolveContext* context, solution_crush_(context->solution_crush()), logger_(context->logger()), time_limit_(context->time_limit()), - interval_representative_(context->working_model->constraints_size(), - IntervalConstraintHash{context->working_model}, - IntervalConstraintEq{context->working_model}) {} + interval_representative_(context->NumConstraints(), + IntervalConstraintHash{&context->WorkingModel()}, + IntervalConstraintEq{&context->WorkingModel()}) { +} CpSolverStatus CpModelPresolver::InfeasibleStatus() { if (logger_->LoggingIsEnabled()) context_->LogInfo(); @@ -13987,8 +14094,7 @@ void InitializeMappingModelVariables(absl::Span domains, } // namespace void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { - const int num_constraints_before_expansion = - context_->working_model->constraints_size(); + const int num_constraints_before_expansion = context_->NumConstraints(); ExpandCpModel(context_); if (context_->ModelIsUnsat()) return; @@ -13996,10 +14102,9 @@ void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { // These are due to ExpandCpModel() were we create such constraint with // duplicate. The problem is that some code assumes these are presolved // before being called. - const int num_constraints = context_->working_model->constraints().size(); + const int num_constraints = context_->NumConstraints(); for (int c = num_constraints_before_expansion; c < num_constraints; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); - const auto type = ct->constraint_case(); + const auto type = context_->Constraint(c).constraint_case(); if (type == ConstraintProto::kAtMostOne || type == ConstraintProto::kExactlyOne) { if (PresolveOneConstraint(c)) { @@ -14008,7 +14113,7 @@ void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { if (context_->ModelIsUnsat()) return; } else if (type == ConstraintProto::kLinear) { bool changed = false; - if (!CanonicalizeLinear(ct, &changed)) { + if (!CanonicalizeLinear(context_->MutableConstraint(c), &changed)) { return; } if (changed) { @@ -14020,38 +14125,18 @@ void CpModelPresolver::ExpandCpModelAndCanonicalizeConstraints() { namespace { -// Updates the solution hint in the proto with the crushed solution values. -void UpdateHintInProto(PresolveContext* context) { - if (context->ModelIsUnsat()) return; - - SolutionCrush& crush = context->solution_crush(); - if (!crush.SolutionIsLoaded()) return; - const int num_vars = context->working_model->variables().size(); - for (int i = 0; i < num_vars; ++i) { - // If the initial hint is incomplete or infeasible, the crushed hint might - // contain values outside of their respective domains (see SolutionCrush). - crush.SetOrUpdateVarToDomain(i, context->DomainOf(i)); - } - // If the time limit is reached, the presolved model might still contain - // non-representative "affine" variables. - for (int i = 0; i < num_vars; ++i) { - const auto relation = context->GetAffineRelation(i); - if (relation.representative != i) { - crush.SetVarToLinearExpression( - i, {{relation.representative, relation.coeff}}, relation.offset); - } - } - crush.StoreSolutionAsHint(*context->working_model); -} - // Canonicalizes the routes constraints node expressions. In particular, // replaces the variables in these expressions with their representative. void CanonicalizeRoutesConstraintNodeExpressions(PresolveContext* context) { - CpModelProto& proto = *context->working_model; - for (ConstraintProto& ct_ref : *proto.mutable_constraints()) { - if (ct_ref.constraint_case() != ConstraintProto::kRoutes) continue; + const int num_constraints = context->NumConstraints(); + for (int c = 0; c < num_constraints; ++c) { + if (context->Constraint(c).constraint_case() != ConstraintProto::kRoutes) { + continue; + } for (RoutesConstraintProto::NodeExpressions& node_exprs : - *ct_ref.mutable_routes()->mutable_dimensions()) { + *context->MutableConstraint(c) + ->mutable_routes() + ->mutable_dimensions()) { for (LinearExpressionProto& expr : *node_exprs.mutable_exprs()) { context->CanonicalizeLinearExpression({}, &expr); } @@ -14065,8 +14150,8 @@ bool CpModelPresolver::CanonicalizeAllLinears() { if (context_->ModelIsUnsat()) return false; std::queue constraints_to_process; absl::flat_hash_set queued; - for (int c = 0; c < context_->working_model->constraints_size(); ++c) { - ConstraintProto& ct = *context_->working_model->mutable_constraints(c); + for (int c = 0; c < context_->NumConstraints(); ++c) { + ConstraintProto& ct = *context_->MutableConstraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; for (const int v : ct.linear().vars()) { if (context_->IsFixed(v)) { @@ -14082,7 +14167,7 @@ bool CpModelPresolver::CanonicalizeAllLinears() { constraints_to_process.pop(); queued.erase(c); - ConstraintProto& ct = *context_->working_model->mutable_constraints(c); + ConstraintProto& ct = *context_->MutableConstraint(c); if (ct.constraint_case() != ConstraintProto::kLinear) continue; bool has_fixed_vars = false; @@ -14134,9 +14219,6 @@ CpSolverStatus CpModelPresolver::Presolve() { context_->InitializeNewDomains(); if (context_->params().cp_model_pure_sat_presolve()) { - context_->UpdateNewConstraintsVariableUsage(); - DCHECK(context_->ConstraintVariableUsageIsConsistent()); - if (!PresolvePureSatProblem()) { (void)context_->NotifyThatModelIsUnsat( "Proved Infeasible during SAT presolve"); @@ -14154,9 +14236,9 @@ CpSolverStatus CpModelPresolver::Presolve() { // Remove all the unused variables from the presolved model. postsolve_mapping_->clear(); - std::vector mapping(context_->working_model->variables_size(), -1); + std::vector mapping(context_->NumVariables(), -1); int num_unused_variables = 0; - for (int i = 0; i < context_->working_model->variables_size(); ++i) { + for (int i = 0; i < context_->NumVariables(); ++i) { if (mapping[i] != -1) continue; // Already mapped. if (context_->VariableWasRemoved(i)) continue; @@ -14188,7 +14270,8 @@ CpSolverStatus CpModelPresolver::Presolve() { DCHECK(context_->ConstraintVariableUsageIsConsistent()); const int old_size = postsolve_mapping_->size(); - ApplyVariableMapping(absl::MakeSpan(mapping), context_->working_model, + ApplyVariableMapping(absl::MakeSpan(mapping), + context_->UnsafeMutableWorkingModel(), postsolve_mapping_); CHECK_EQ(old_size, postsolve_mapping_->size()); if (context_->lrat_proof_handler != nullptr) { @@ -14207,10 +14290,10 @@ CpSolverStatus CpModelPresolver::Presolve() { // TODO(user): We should probably try to delay this even more. For that we // just need to isolate more the "dual" reduction that usually need to look at // the objective. - if (context_->working_model->has_floating_point_objective()) { + if (context_->WorkingModel().has_floating_point_objective()) { context_->WriteVariableDomainsToProto(); if (!ScaleFloatingPointObjective(context_->params(), logger_, - context_->working_model)) { + context_->UnsafeMutableWorkingModel())) { SOLVER_LOG(logger_, "The floating point objective cannot be scaled with enough " "precision"); @@ -14222,7 +14305,7 @@ CpSolverStatus CpModelPresolver::Presolve() { // we can expose to the user what exact objective we are actually // optimizing. *context_->mapping_model->mutable_objective() = - context_->working_model->objective(); + context_->WorkingModel().objective(); } // If there is a large proprotion of fixed variables, lets remap the model @@ -14232,6 +14315,10 @@ CpSolverStatus CpModelPresolver::Presolve() { // that will be at position i after MaybeRemoveFixedVariables(). If the // mapping is left empty, it will be set to the identity mapping later by // InitializeMappingModelVariables(). + // + // TODO(user): Integrate this with first copy of the main model instead ? It + // needs a bit more work since we don't currently canonicalize and remap at + // the same time. std::vector fixed_postsolve_mapping; if (!MaybeRemoveFixedVariables(&fixed_postsolve_mapping)) { return InfeasibleStatus(); @@ -14241,30 +14328,28 @@ CpSolverStatus CpModelPresolver::Presolve() { // Initialize the objective and the constraint <-> variable graph. // // Note that we did some basic presolving during the first copy of the model. - // This is important has initializing the constraint <-> variable graph can + // This is important since initializing the constraint <-> variable graph can // be costly, so better to remove trivially feasible constraint for instance. context_->InitializeNewDomains(); - context_->LoadSolutionHint(); + if (!context_->solution_crush().SolutionIsLoaded()) { + context_->LoadAndClampSolutionHint(); + } context_->ReadObjectiveFromProto(); if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); - context_->UpdateNewConstraintsVariableUsage(); context_->RegisterVariablesUsedInAssumptions(); DCHECK(context_->ConstraintVariableUsageIsConsistent()); // If presolve is false, just run expansion. if (!context_->params().cp_model_presolve()) { - if (!solution_crush_.SolutionIsLoaded()) { - context_->LoadSolutionHint(); - } ExpandCpModelAndCanonicalizeConstraints(); if (!CanonicalizeAllLinears()) return InfeasibleStatus(); - UpdateHintInProto(context_); + context_->WriteHintToProto(); if (context_->ModelIsUnsat()) return InfeasibleStatus(); // We still write back the canonical objective has we don't deal well // with uninitialized domain or duplicate variables. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { context_->WriteObjectiveToProto(); } @@ -14286,8 +14371,9 @@ CpSolverStatus CpModelPresolver::Presolve() { if (!context_->mapping_model->constraints().empty()) { context_->UpdateRuleStats( "TODO: mapping model not empty with presolve disabled"); - context_->working_model->mutable_constraints()->MergeFrom( - context_->mapping_model->constraints()); + for (const ConstraintProto& ct : context_->mapping_model->constraints()) { + *context_->AddConstraint() = ct; + } context_->mapping_model->clear_constraints(); } @@ -14298,7 +14384,7 @@ CpSolverStatus CpModelPresolver::Presolve() { // Presolve all variable domain once. The PresolveToFixPoint() function will // only reprocess domain that changed. if (context_->ModelIsUnsat()) return InfeasibleStatus(); - for (int var = 0; var < context_->working_model->variables().size(); ++var) { + for (int var = 0; var < context_->NumVariables(); ++var) { if (context_->VariableIsNotUsedAnymore(var)) continue; if (!PresolveAffineRelationIfAny(var)) return InfeasibleStatus(); @@ -14307,7 +14393,6 @@ CpSolverStatus CpModelPresolver::Presolve() { // if they already exist in the model. TryToSimplifyDomain(var); if (context_->ModelIsUnsat()) return InfeasibleStatus(); - context_->UpdateNewConstraintsVariableUsage(); } if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); @@ -14326,6 +14411,7 @@ CpSolverStatus CpModelPresolver::Presolve() { // code. See for instance plusexample_6_sat.fzn were represolving the // presolved problem reduces it even more. PresolveToFixPoint(); + DCHECK(context_->ConstraintVariableUsageIsConsistent()); // Call expansion. if (!context_->ModelIsExpanded()) { @@ -14335,10 +14421,10 @@ CpSolverStatus CpModelPresolver::Presolve() { // We need to re-evaluate the degree because some presolve rule only // run after expansion. - const int num_vars = context_->working_model->variables().size(); + const int num_vars = context_->NumVariables(); for (int var = 0; var < num_vars; ++var) { if (context_->VarToConstraints(var).size() <= 3) { - context_->var_with_reduced_small_degree.Set(var); + context_->MutableVarWithReducedSmallDegree()->Set(var); } } } @@ -14361,20 +14447,15 @@ CpSolverStatus CpModelPresolver::Presolve() { DetectDuplicateConstraints(); if (context_->params().keep_symmetry_in_presolve()) { // If the presolve always keep symmetry, we compute it once and for all. - if (!context_->working_model->has_symmetry()) { - DetectAndAddSymmetryToProto(context_->params(), - context_->working_model, logger_, - context_->time_limit()); - } - - // We distinguish an empty symmetry message meaning that symmetry were - // computed and there is none, and the absence of symmetry message - // meaning we don't know. // - // TODO(user): Maybe this is a bit brittle. Also move this logic to - // DetectAndAddSymmetryToProto() ? - if (!context_->working_model->has_symmetry()) { - context_->working_model->mutable_symmetry()->Clear(); + // Note that this will always create a "symmetry" message even if it is + // empty. We use that to know that symmetry where computed, and there is + // none. + if (!context_->WorkingModel().has_symmetry()) { + DetectAndAddSymmetryToProto(context_->params(), + context_->WorkingModel(), + context_->MutableWorkingModelSymmetry(), + logger_, context_->time_limit()); } } else if (!context_->params() .keep_all_feasible_solutions_in_presolve()) { @@ -14404,13 +14485,12 @@ CpSolverStatus CpModelPresolver::Presolve() { // and add them to bool_and clauses? this is some sort of small scale // probing, but good for sat presolve and clique later? if (!context_->ModelIsUnsat() && iter == 0) { - const int old_size = context_->working_model->constraints_size(); + const int old_size = context_->NumConstraints(); for (int c = 0; c < old_size; ++c) { - ConstraintProto* ct = context_->working_model->mutable_constraints(c); + ConstraintProto* ct = context_->MutableConstraint(c); if (ct->constraint_case() != ConstraintProto::kLinear) continue; ExtractAtMostOneFromLinear(ct); } - context_->UpdateNewConstraintsVariableUsage(); } if (context_->params().cp_model_probing_level() > 0) { @@ -14433,6 +14513,7 @@ CpSolverStatus CpModelPresolver::Presolve() { DetectDifferentVariables(); ProcessSetPPC(); TransformClausesToExactlyOne(); + SplitNoOverlapAndCumulativeConstraints(); if (!time_limit_->LimitReached() && context_->params().detect_encoded_complex_domain()) { @@ -14446,7 +14527,7 @@ CpSolverStatus CpModelPresolver::Presolve() { FindAlmostIdenticalLinearConstraints(); ActivityBoundHelper activity_amo_helper; - activity_amo_helper.AddAllAtMostOnes(*context_->working_model); + activity_amo_helper.AddAllAtMostOnes(context_->WorkingModel()); FindBigAtMostOneAndLinearOverlap(&activity_amo_helper); // Heuristic: vertical introduce smaller defining constraint and appear in @@ -14487,7 +14568,7 @@ CpSolverStatus CpModelPresolver::Presolve() { // Tries to spread the objective amongst many variables. // We re-do a canonicalization with the final linear expression. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { if (!context_->params().keep_symmetry_in_presolve()) { ExpandObjective(); if (!context_->modified_domains.PositionsSetAtLeastOnce().empty()) { @@ -14508,14 +14589,14 @@ CpSolverStatus CpModelPresolver::Presolve() { // Take care of linear constraint with a complex rhs. FinalExpansionForLinearConstraint(context_); - // Adds all needed affine relation to context_->working_model. + // Adds all needed affine relation to working_model. EncodeAllAffineRelations(); if (context_->ModelIsUnsat()) return InfeasibleStatus(); // If we have symmetry information, lets filter it. - if (context_->working_model->has_symmetry()) { + if (context_->WorkingModel().has_symmetry()) { if (!FilterOrbitOnUnusedOrFixedVariables( - context_->working_model->mutable_symmetry(), context_)) { + context_->MutableWorkingModelSymmetry(), context_)) { return InfeasibleStatus(); } } @@ -14530,7 +14611,7 @@ CpSolverStatus CpModelPresolver::Presolve() { // affine transformation in order to preserve the order. absl::flat_hash_set used_variables; for (DecisionStrategyProto& strategy : - *context_->working_model->mutable_search_strategy()) { + *context_->UnsafeMutableWorkingModel()->mutable_search_strategy()) { CHECK(strategy.variables().empty()); if (strategy.exprs().empty()) continue; @@ -14558,12 +14639,12 @@ CpSolverStatus CpModelPresolver::Presolve() { // Some vars may have been fixed by the affine relations. This may can impact // the objective. Let's re-do the canonicalization. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { // We re-do a canonicalization with the final linear expression. if (!context_->CanonicalizeObjective()) return InfeasibleStatus(); context_->WriteObjectiveToProto(); DCHECK(absl::c_all_of( - context_->working_model->objective().vars(), + context_->WorkingModel().objective().vars(), [ctx = context_](int var) { return !ctx->IsFixed(var); })); } @@ -14574,10 +14655,10 @@ CpSolverStatus CpModelPresolver::Presolve() { // Remove all the unused variables from the presolved model. postsolve_mapping_->clear(); - std::vector mapping(context_->working_model->variables_size(), -1); + std::vector mapping(context_->NumVariables(), -1); absl::flat_hash_map constant_to_index; int num_unused_variables = 0; - for (int i = 0; i < context_->working_model->variables_size(); ++i) { + for (int i = 0; i < context_->NumVariables(); ++i) { if (mapping[i] != -1) continue; // Already mapped. if (context_->VariableWasRemoved(i)) { @@ -14633,9 +14714,12 @@ CpSolverStatus CpModelPresolver::Presolve() { DCHECK(context_->ConstraintVariableUsageIsConsistent()); CanonicalizeRoutesConstraintNodeExpressions(context_); - UpdateHintInProto(context_); + context_->WriteHintToProto(); + + // Context shouldn't really be used after this since everything was remapped. const int old_size = postsolve_mapping_->size(); - ApplyVariableMapping(absl::MakeSpan(mapping), context_->working_model, + ApplyVariableMapping(absl::MakeSpan(mapping), + context_->UnsafeMutableWorkingModel(), postsolve_mapping_); CHECK_EQ(old_size, postsolve_mapping_->size()); @@ -14657,8 +14741,8 @@ void CpModelPresolver::MaybePermuteVariablesRandomly( const int n = postsolve_mapping_->size(); std::vector perm(n); std::iota(perm.begin(), perm.end(), 0); - std::shuffle(perm.begin(), perm.end(), *context_->random()); - for (int i = 0; i < context_->working_model->variables_size(); ++i) { + std::shuffle(perm.begin(), perm.end(), context_->random()); + for (int i = 0; i < context_->NumVariables(); ++i) { if (mapping[i] != -1) mapping[i] = perm[mapping[i]]; } std::vector new_postsolve_mapping(n); @@ -14677,7 +14761,7 @@ CpSolverStatus CpModelPresolver::LogAndValidatePresolvedModel() { // situation that causes the final model to be invalid. { const std::string error = - ValidateCpModel(*context_->working_model, /*after_presolve=*/true); + ValidateCpModel(context_->WorkingModel(), /*after_presolve=*/true); if (!error.empty()) { SOLVER_LOG(logger_, "Error while validating postsolved model: ", error); return CpSolverStatus::MODEL_INVALID; @@ -14864,7 +14948,7 @@ bool CpModelPresolver::MaybeRemoveFixedVariables( if (context_->ModelIsUnsat()) return false; // Initialize the mapping to remove all fixed variables. - const int num_vars = context_->working_model->variables().size(); + const int num_vars = context_->NumVariables(); std::vector mapping(num_vars, -1); for (int i = 0; i < num_vars; ++i) { if (context_->IsFixed(i)) continue; @@ -14882,7 +14966,7 @@ bool CpModelPresolver::MaybeRemoveFixedVariables( // TODO(user): Right now the copy does not remove fixed variables from the // objective, but ReadObjectiveFromProto() does it. Maybe we should just not // copy them in the first place. - if (context_->working_model->has_objective()) { + if (context_->WorkingModel().has_objective()) { context_->ReadObjectiveFromProto(); if (!context_->CanonicalizeObjective()) return false; if (!PropagateObjective()) return false; @@ -14894,7 +14978,7 @@ bool CpModelPresolver::MaybeRemoveFixedVariables( // Note that we are not sure the domain where properly written. context_->WriteVariableDomainsToProto(); *context_->mapping_model->mutable_variables() = - context_->working_model->variables(); + context_->WorkingModel().variables(); SOLVER_LOG(logger_, "Large number of fixed variables ", FormatCounter(num_fixed), " / ", FormatCounter(num_vars), @@ -14904,7 +14988,8 @@ bool CpModelPresolver::MaybeRemoveFixedVariables( // Perform the actual mapping. // Note that this might re-add fixed variable that are still used. const int old_size = postsolve_mapping->size(); - ApplyVariableMapping(absl::MakeSpan(mapping), context_->working_model, + ApplyVariableMapping(absl::MakeSpan(mapping), + context_->UnsafeMutableWorkingModel(), postsolve_mapping); if (postsolve_mapping->size() > old_size) { const int new_extra = postsolve_mapping->size() - old_size; @@ -14929,16 +15014,16 @@ ConstraintProto CopyObjectiveForDuplicateDetection( } struct ConstraintHashForDuplicateDetection { - const CpModelProto* working_model; + const CpModelProto& cp_model; bool ignore_enforcement; ConstraintProto objective_constraint; ConstraintHashForDuplicateDetection(const CpModelProto* working_model, bool ignore_enforcement) - : working_model(working_model), + : cp_model(*working_model), ignore_enforcement(ignore_enforcement), objective_constraint( - CopyObjectiveForDuplicateDetection(working_model->objective())) {} + CopyObjectiveForDuplicateDetection(cp_model.objective())) {} // We hash our mostly frequently used constraint directly without extra memory // allocation. We revert to a generic code using proto serialization for the @@ -14946,7 +15031,7 @@ struct ConstraintHashForDuplicateDetection { std::size_t operator()(int ct_idx) const { const ConstraintProto& ct = ct_idx == kObjectiveConstraint ? objective_constraint - : working_model->constraints(ct_idx); + : cp_model.constraints(ct_idx); const std::pair> type_and_enforcement = {ct.constraint_case(), ignore_enforcement @@ -14978,6 +15063,17 @@ struct ConstraintHashForDuplicateDetection { case ConstraintProto::kExactlyOne: return absl::HashOf(type_and_enforcement, absl::MakeSpan(ct.exactly_one().literals())); + case ConstraintProto::kInterval: + return absl::HashOf(type_and_enforcement, + absl::MakeSpan(ct.interval().start().vars()), + absl::MakeSpan(ct.interval().start().coeffs()), + ct.interval().start().offset(), + absl::MakeSpan(ct.interval().size().vars()), + absl::MakeSpan(ct.interval().size().coeffs()), + ct.interval().size().offset(), + absl::MakeSpan(ct.interval().end().vars()), + absl::MakeSpan(ct.interval().end().coeffs()), + ct.interval().end().offset()); default: ConstraintProto copy = ct; copy.clear_name(); @@ -14990,16 +15086,16 @@ struct ConstraintHashForDuplicateDetection { }; struct ConstraintEqForDuplicateDetection { - const CpModelProto* working_model; + const CpModelProto& cp_model; bool ignore_enforcement; ConstraintProto objective_constraint; ConstraintEqForDuplicateDetection(const CpModelProto* working_model, bool ignore_enforcement) - : working_model(working_model), + : cp_model(*working_model), ignore_enforcement(ignore_enforcement), objective_constraint( - CopyObjectiveForDuplicateDetection(working_model->objective())) {} + CopyObjectiveForDuplicateDetection(cp_model.objective())) {} bool operator()(int a, int b) const { if (a == b) { @@ -15007,10 +15103,10 @@ struct ConstraintEqForDuplicateDetection { } const ConstraintProto& ct_a = a == kObjectiveConstraint ? objective_constraint - : working_model->constraints(a); + : cp_model.constraints(a); const ConstraintProto& ct_b = b == kObjectiveConstraint ? objective_constraint - : working_model->constraints(b); + : cp_model.constraints(b); if (ct_a.constraint_case() != ct_b.constraint_case()) return false; if (!ignore_enforcement) { @@ -15043,6 +15139,20 @@ struct ConstraintEqForDuplicateDetection { case ConstraintProto::kExactlyOne: return absl::MakeSpan(ct_a.exactly_one().literals()) == absl::MakeSpan(ct_b.exactly_one().literals()); + case ConstraintProto::kInterval: { + auto compare_linear_expression = [](const LinearExpressionProto& a, + const LinearExpressionProto& b) { + return absl::MakeSpan(a.vars()) == absl::MakeSpan(b.vars()) && + absl::MakeSpan(a.coeffs()) == absl::MakeSpan(b.coeffs()) && + a.offset() == b.offset(); + }; + return compare_linear_expression(ct_a.interval().start(), + ct_b.interval().start()) && + compare_linear_expression(ct_a.interval().size(), + ct_b.interval().size()) && + compare_linear_expression(ct_a.interval().end(), + ct_b.interval().end()); + } default: // Slow (hopefully comparably rare) path. ConstraintProto copy_a = ct_a; diff --git a/ortools/sat/cp_model_presolve.h b/ortools/sat/cp_model_presolve.h index 1993726940c..0bd1ce7162b 100644 --- a/ortools/sat/cp_model_presolve.h +++ b/ortools/sat/cp_model_presolve.h @@ -138,6 +138,7 @@ class CpModelPresolver { bool PresolveIntMod(int c, ConstraintProto* ct); bool PresolveIntProd(ConstraintProto* ct); bool PresolveInterval(int c, ConstraintProto* ct); + bool PresolveLegacyInverse(ConstraintProto* ct); bool PresolveInverse(ConstraintProto* ct); bool DivideLinMaxByGcd(int c, ConstraintProto* ct); bool PresolveLinMax(int c, ConstraintProto* ct); @@ -156,8 +157,8 @@ class CpModelPresolver { bool PresolveCircuit(ConstraintProto* ct); bool PresolveRoutes(ConstraintProto* ct); - bool PresolveAtMostOrExactlyOne(ConstraintProto* ct); - bool PresolveAtMostOne(ConstraintProto* ct); + bool PresolveAtMostOrExactlyOne(ConstraintProto* ct, bool use_dual_reduction); + bool PresolveAtMostOne(ConstraintProto* ct, bool use_dual_reduction = true); bool PresolveExactlyOne(ConstraintProto* ct); bool PresolveBoolAnd(ConstraintProto* ct); @@ -178,7 +179,7 @@ class CpModelPresolver { ABSL_MUST_USE_RESULT bool CanonicalizeAllLinears(); bool PropagateDomainsInLinear(int ct_index, ConstraintProto* ct); bool RemoveSingletonInLinear(ConstraintProto* ct); - bool PresolveSmallLinear(ConstraintProto* ct); + bool PresolveSmallLinear(ConstraintProto* ct, bool canonicalize = true); bool PresolveEmptyLinearConstraint(ConstraintProto* ct); bool PresolveLinearOfSizeOne(ConstraintProto* ct); bool PresolveLinearOfSizeTwo(ConstraintProto* ct); @@ -320,6 +321,11 @@ class CpModelPresolver { // one constraint. void TransformClausesToExactlyOne(); + // Use all the detected precedences to detect if a part of a no_overlap + // constraint can only be executed after the rest and thus the no_overlap + // constraint can be split into smaller no_overlap constraints. + void SplitNoOverlapAndCumulativeConstraints(); + // Converts bool_or and at_most_one of size 2 to bool_and. void ConvertToBoolAnd(); @@ -405,6 +411,9 @@ class CpModelPresolver { void MaybePermuteVariablesRandomly(std::vector& mapping); CpSolverStatus LogAndValidatePresolvedModel(); + void AddLinear2ToModel(const LinearExpression2& linear2, int64_t lb, + int64_t ub); + std::vector* postsolve_mapping_; PresolveContext* context_; SolutionCrush& solution_crush_; @@ -454,7 +463,12 @@ class CpModelPresolver { // // We reuse an IntegerVariable/IntegerValue based class via // GetLinearExpression2FromProto() only visible in the .cc. + // + // We have two versions of this map: one that only consider linear2 that + // are encoded as such in the model, and a more general one that consider any + // linear2 that was detected by the presolve. BestBinaryRelationBounds known_linear2_; + BestBinaryRelationBounds known_model_linear2_; struct IntervalConstraintEq { const CpModelProto* working_model; diff --git a/ortools/sat/cp_model_presolve_test.cc b/ortools/sat/cp_model_presolve_test.cc index 655c11850bd..8291c94a04b 100644 --- a/ortools/sat/cp_model_presolve_test.cc +++ b/ortools/sat/cp_model_presolve_test.cc @@ -179,7 +179,6 @@ CpModelProto PresolveOneConstraint(const CpModelProto& initial_model, PresolveContext context(&model, &presolved_model, &mapping_model); CpModelPresolver presolver(&context, &mapping); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); presolver.PresolveOneConstraint(constraint_index); presolver.RemoveEmptyConstraints(); for (int i = 0; i < presolved_model.variables_size(); ++i) { @@ -4974,7 +4973,6 @@ TEST(PresolveCpModelTest, DetectDifferentVariablesAndAddNoOverlap) { std::vector mapping; CpModelPresolver presolver(&context, &mapping); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); presolver.DetectDifferentVariables(); context.WriteVariableDomainsToProto(); @@ -7401,8 +7399,12 @@ TEST(PresolveCpModelTest, ReduceDomainsInInverse) { variables { domain: [ 0, 2 ] } constraints { inverse { - f_direct: [ 0, 1, 2 ] - f_inverse: [ 3, 4, 5 ] + f_expr_direct: { vars: 0 coeffs: 1 } + f_expr_direct: { vars: 1 coeffs: 1 } + f_expr_direct: { vars: 2 coeffs: 1 } + f_expr_inverse: { vars: 3 coeffs: 1 } + f_expr_inverse: { vars: 4 coeffs: 1 } + f_expr_inverse: { vars: 5 coeffs: 1 } } } )pb"); @@ -8458,7 +8460,6 @@ TEST(PresolveCpModelTest, DuplicateColumns) { CpModelPresolver presolver(&context, &mapping); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); presolver.DetectDuplicateColumns(); context.WriteVariableDomainsToProto(); diff --git a/ortools/sat/cp_model_search.cc b/ortools/sat/cp_model_search.cc index 8d0f3133c96..46e7995d960 100644 --- a/ortools/sat/cp_model_search.cc +++ b/ortools/sat/cp_model_search.cc @@ -30,6 +30,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_utils.h" @@ -223,7 +224,7 @@ std::function ConstructUserSearchStrategy( return [&view, ¶meters, random, strategies]() { for (const DecisionStrategyProto& strategy : strategies) { int candidate_ref = -1; - int64_t candidate_value = std::numeric_limits::max(); + int64_t candidate_value = kint64max; // TODO(user): Improve the complexity if this becomes an issue which // may be the case if we do a fixed_search. @@ -295,7 +296,7 @@ std::function ConstructUserSearchStrategy( } // Check if one active variable has been found. - if (candidate_value == std::numeric_limits::max()) continue; + if (candidate_value == kint64max) continue; // Pick the winner when decisions are randomized. if (randomize_decision) { @@ -662,6 +663,7 @@ absl::flat_hash_map GetNamedParameters( if (base_params.use_dual_scheduling_heuristics()) { AddExtraSchedulingPropagators(new_params); } + new_params.set_cut_level(0); strategies["probing"] = new_params; new_params.set_linearization_level(0); @@ -669,6 +671,7 @@ absl::flat_hash_map GetNamedParameters( // We want to spend more time on the LP here. new_params.set_linearization_level(2); + new_params.set_cut_level(1); new_params.set_add_lp_constraints_lazily(false); new_params.set_root_lp_iterations(100'000); strategies["probing_max_lp"] = new_params; diff --git a/ortools/sat/cp_model_search_test.cc b/ortools/sat/cp_model_search_test.cc index 155a8778350..c0778fbd78b 100644 --- a/ortools/sat/cp_model_search_test.cc +++ b/ortools/sat/cp_model_search_test.cc @@ -24,6 +24,7 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/parse_test_proto.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" #include "ortools/sat/model.h" @@ -362,7 +363,7 @@ TEST(BasicFixedSearchBehaviorTest, RandomHalfTest) { } EXPECT_EQ(count_by_solution.size(), kNumExpectedSolutions); DoubleDistribution counts; - int min_count = std::numeric_limits::max(); + int min_count = kint32max; std::tuple min_count_solution; int max_count = 0; std::tuple max_count_solution; diff --git a/ortools/sat/cp_model_solver.cc b/ortools/sat/cp_model_solver.cc index bbefa0cb1ca..12be4af46bd 100644 --- a/ortools/sat/cp_model_solver.cc +++ b/ortools/sat/cp_model_solver.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -56,7 +55,8 @@ #include "ortools/base/macros/os_support.h" #include "ortools/base/options.h" #include "ortools/base/timer.h" -#include "ortools/base/version.h" +#include "ortools/base/types.h" +#include "ortools/base/version.h" // IWYU pragma: keep #include "ortools/port/proto_utils.h" #include "ortools/sat/combine_solutions.h" #include "ortools/sat/cp_model.pb.h" @@ -89,6 +89,7 @@ #include "ortools/sat/sat_inprocessing.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" +#include "ortools/sat/scheduling_local_search.h" #include "ortools/sat/shaving_solver.h" #include "ortools/sat/stat_tables.h" #include "ortools/sat/subsolver.h" @@ -571,8 +572,8 @@ std::string CpModelStats(const CpModelProto& model_proto) { } } else { int64_t max_complexity = 0; - int64_t min = std::numeric_limits::max(); - int64_t max = std::numeric_limits::min(); + int64_t min = kint64max; + int64_t max = kint64min; for (const auto& entry : num_vars_per_domains) { min = std::min(min, entry.first.Min()); max = std::max(max, entry.first.Max()); @@ -895,8 +896,7 @@ bool RestrictObjectiveUsingHint(CpModelProto* model_proto) { const int64_t obj_upper_bound = ComputeInnerObjective(model_proto->objective(), solution); - const Domain restriction = - Domain(std::numeric_limits::min(), obj_upper_bound); + const Domain restriction = Domain(kint64min, obj_upper_bound); if (restriction.IsEmpty()) return false; @@ -1142,6 +1142,13 @@ class FullProblemSolver : public SubSolver { RegisterClausesLevelZeroImport(id, shared_->clauses.get(), &local_model_); RegisterClausesExport(id, shared_->clauses.get(), &local_model_); + + // Hack to export all equivalences found so far. + // + // TODO(user): We probably want to do probing "AFTER" this. But then + // we might not want to export all binary clauses found by probing. + local_model_.GetOrCreate() + ->ExportAllEquivalences(); } auto* logger = local_model_.GetOrCreate(); @@ -1435,10 +1442,7 @@ class LnsSolver : public SubSolver { *google::protobuf::Arena::Create(&arena); CpModelProto& mapping_proto = *google::protobuf::Arena::Create(&arena); - auto context = std::make_unique( - &local_model, &lns_fragment, &mapping_proto); - *lns_fragment.mutable_variables() = neighborhood.delta.variables(); std::vector variable_mapping; std::vector fixed_values; { @@ -1470,12 +1474,20 @@ class LnsSolver : public SubSolver { // TODO(user): the mapping removes fixed variables but the model // copy can fix new ones. Should we update the mapping and do a new // copy, and so on until fix point? - std::vector reverse_mapping; - if (!GenerateMapping(context.get(), variable_mapping, reverse_mapping, + if (!GenerateMapping(neighborhood.delta, variable_mapping, fixed_values)) { return; } - ModelCopy copier(context.get(), variable_mapping, reverse_mapping); + ModelCopy copier(&lns_fragment, &local_model, variable_mapping); + if (!copier.ImportVariables(neighborhood.delta)) return; + + if (use_hint) { + if (neighborhood.delta.has_solution_hint()) { + copier.ImportSolutionHint(neighborhood.delta); + } else { + copier.ImportSolutionHint(helper_->ModelProto()); + } + } // Copy and simplify the constraints from the initial model. if (!copier.ImportAndSimplifyConstraints(helper_->ModelProto())) { @@ -1488,14 +1500,6 @@ class LnsSolver : public SubSolver { return; } - if (use_hint) { - if (neighborhood.delta.has_solution_hint()) { - copier.ImportSolutionHint(neighborhood.delta); - } else { - copier.ImportSolutionHint(helper_->ModelProto()); - } - } - // Copy the rest of the model, except symmetries (we don't want to use // the symmetry of the main problem in the LNS presolved problem). if (!copier.ImportEverythingExceptVariablesConstraintsAndHint( @@ -1503,9 +1507,7 @@ class LnsSolver : public SubSolver { return; } - if (!copier.RemapVariablesInProtoAndContext()) { - return; - } + if (!copier.FinishCopy(neighborhood.delta)) return; } lns_fragment.set_name(absl::StrCat("lns_", task_id, "_", source_info)); @@ -1544,6 +1546,9 @@ class LnsSolver : public SubSolver { const int num_vars_before_presolve = lns_fragment.variables_size(); std::vector postsolve_mapping; + + auto context = std::make_unique( + &local_model, &lns_fragment, &mapping_proto); const CpSolverStatus presolve_status = PresolveCpModel(context.get(), &postsolve_mapping); @@ -1615,10 +1620,11 @@ class LnsSolver : public SubSolver { mapping_proto, postsolve_mapping, &local_solution_values); // Map the solution back to the original variables. - const int num_vars = variable_mapping.size(); + const int num_vars = helper_->ModelProto().variables().size(); solution_values.reserve(num_vars); for (int i = 0; i < num_vars; ++i) { - const int mapped_ref = variable_mapping[i]; + const int mapped_ref = + variable_mapping.empty() ? i : variable_mapping[i]; if (mapped_ref != kNoVariableMapping) { int64_t value = local_solution_values[PositiveRef(mapped_ref)]; if (RefIsPositive(mapped_ref)) { @@ -1629,7 +1635,7 @@ class LnsSolver : public SubSolver { solution_values.push_back(1 - value); } } else { - DCHECK_NE(fixed_values[i], std::numeric_limits::min()); + DCHECK_NE(fixed_values[i], kint64min); solution_values.push_back(fixed_values[i]); } } @@ -1776,12 +1782,12 @@ class LnsSolver : public SubSolver { } private: - // Generates a mapping which removes fixed variables (except those in kInverse - // constraints, and one fixed literal). - bool GenerateMapping(PresolveContext* context, + // Generates a mapping which removes fixed variables (except one fixed + // literal). + bool GenerateMapping(CpModelProto& proto_with_variables, std::vector& variable_mapping, - std::vector& reverse_mapping, std::vector& fixed_values) { + int new_var_index = 0; std::vector representatives; if (shared_->clauses != nullptr) { representatives = shared_->clauses->GetRepresentatives(); @@ -1793,39 +1799,45 @@ class LnsSolver : public SubSolver { // Fixed variables can be removed from the model. If a variable is fixed // then the equivalent variables can be fixed too. - const CpModelProto& lns_fragment = *context->working_model; - const int num_vars = lns_fragment.variables_size(); - context->InitializeNewDomains(); - auto fix_literal = [&](int literal, bool value) { - return value ? context->SetLiteralToTrue(literal) - : context->SetLiteralToFalse(literal); + const int num_vars = proto_with_variables.variables_size(); + auto is_fixed = [&](int ref) { + const int var = PositiveRef(ref); + const auto& domain = proto_with_variables.variables(var).domain(); + return domain[0] == domain[domain.size() - 1]; + }; + auto fixed_literal_value = [&](int ref) { + const int var = PositiveRef(ref); + const int value = proto_with_variables.variables(var).domain(0); + return RefIsPositive(ref) ? value : 1 - value; + }; + auto fix_literal = [&](int literal, int value) { + if (!RefIsPositive(literal)) { + literal = PositiveRef(literal); + value = 1 - value; + } + if (is_fixed(literal) && fixed_literal_value(literal) != value) { + return false; + } + auto* domain = + proto_with_variables.mutable_variables(literal)->mutable_domain(); + domain->Clear(); + domain->Add(value); + domain->Add(value); + return true; }; for (int i = 0; i < num_vars; ++i) { - if (context->IsFixed(i)) { + if (proto_with_variables.variables(i).domain().empty()) return false; + if (is_fixed(i)) { const int rep = get_representative(i); if (rep != i) { - if (!fix_literal(rep, context->LiteralIsTrue(i))) return false; + if (!fix_literal(rep, fixed_literal_value(i))) return false; } } } for (int i = 0; i < num_vars; ++i) { const int rep = get_representative(i); - if (rep != i && context->IsFixed(rep)) { - if (!fix_literal(i, context->LiteralIsTrue(rep))) return false; - } - } - - // ModelCopy does not support removing variables appearing in kInverse - // constraints. - absl::flat_hash_set unremovable_vars; - for (const ConstraintProto& ct : lns_fragment.constraints()) { - if (ct.constraint_case() == ConstraintProto::kInverse) { - for (int var : ct.inverse().f_direct()) { - unremovable_vars.insert(var); - } - for (int var : ct.inverse().f_inverse()) { - unremovable_vars.insert(var); - } + if (rep != i && is_fixed(rep)) { + if (!fix_literal(i, fixed_literal_value(rep))) return false; } } @@ -1834,26 +1846,25 @@ class LnsSolver : public SubSolver { fixed_values.reserve(num_vars); for (int i = 0; i < num_vars; ++i) { bool add_to_mapping = true; - if (context->IsFixed(i)) { - const int64_t value = context->FixedValue(i); + if (is_fixed(i)) { + const int64_t value = proto_with_variables.variables(i).domain(0); fixed_values.push_back(value); - add_to_mapping = unremovable_vars.contains(i); + add_to_mapping = false; if (first_fixed_literal == -1 && (value == 0 || value == 1)) { first_fixed_literal = i; add_to_mapping = true; } } else { - fixed_values.push_back(std::numeric_limits::min()); + fixed_values.push_back(kint64min); } if (add_to_mapping && variable_mapping[i] == kNoVariableMapping) { const int rep = get_representative(i); const int rep_var = PositiveRef(rep); if (variable_mapping[rep_var] == kNoVariableMapping) { - variable_mapping[i] = reverse_mapping.size(); + variable_mapping[i] = new_var_index++; variable_mapping[rep_var] = RefIsPositive(rep) ? variable_mapping[i] : NegatedRef(variable_mapping[i]); - reverse_mapping.push_back(i); } else { variable_mapping[i] = RefIsPositive(rep) ? variable_mapping[rep_var] @@ -2019,6 +2030,10 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { lns_params_base, lns_params_stalling, helper, shared)); } + const bool has_no_overlap_or_cumulative = + !helper->TypeToConstraints(ConstraintProto::kNoOverlap).empty() || + !helper->TypeToConstraints(ConstraintProto::kCumulative).empty(); + // Add incomplete subsolvers that require an objective. // // They are all re-entrant, so we do not need to specify more than the number @@ -2073,10 +2088,6 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { lns_params_base, lns_params_stalling, helper, shared)); } - const bool has_no_overlap_or_cumulative = - !helper->TypeToConstraints(ConstraintProto::kNoOverlap).empty() || - !helper->TypeToConstraints(ConstraintProto::kCumulative).empty(); - // Scheduling (no_overlap and cumulative) specific LNS. if (has_no_overlap_or_cumulative) { if (name_filter.Keep("scheduling_intervals_lns")) { @@ -2224,6 +2235,14 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) { } } + if (has_no_overlap_or_cumulative) { + interleaved_subsolvers.push_back( + std::make_unique( + "ls_scheduling", SubSolver::INCOMPLETE, shared->model_proto, + params, shared->time_limit, shared->response, + shared->stat_tables)); + } + if (num_ls_lin > 0) { std::shared_ptr lin_states = std::make_shared(lin_ls_name, params, @@ -2483,12 +2502,13 @@ void FixVariablesToHintValue(const PartialVariableAssignment& solution_hint, PresolveContext* context, SolverLogger* logger) { SOLVER_LOG(logger, "Fixing ", solution_hint.vars().size(), " variables to their value in the solution hints."); + context->InitializeNewDomains(); for (int i = 0; i < solution_hint.vars_size(); ++i) { const int var = solution_hint.vars(i); const int64_t value = solution_hint.values(i); if (!context->IntersectDomainWith(var, Domain(value))) { const IntegerVariableProto& var_proto = - context->working_model->variables(var); + context->WorkingModel().variables(var); const std::string var_name = var_proto.name().empty() ? absl::StrCat("var(", var, ")") : var_proto.name(); @@ -2692,7 +2712,9 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { // Presolve and expansions. SOLVER_LOG(logger, ""); SOLVER_LOG(logger, - absl::StrFormat("Starting presolve at %.2fs", wall_timer->Get())); + absl::StrFormat("Starting initial copy and canonicalization of " + "the input proto at %.2fs", + wall_timer->Get())); // Note: Allocating in an arena significantly speed up destruction (free) for // large messages. @@ -2701,14 +2723,23 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { google::protobuf::Arena::Create(&arena); CpModelProto* mapping_proto = google::protobuf::Arena::Create(&arena); - auto context = std::make_unique(model, new_cp_model_proto, - mapping_proto); - if (!ImportModelWithBasicPresolveIntoContext(model_proto, context.get())) { + // The lrat proof handler is needed is some cases, during the initial copy + // and the presolve. + // + // Note that this is the "presolve one", each worker will have its own. + std::unique_ptr presolve_lrat_proof_handler = + LratProofHandler::MaybeCreate( + model, /*enable_rat_proofs=*/params.cp_model_pure_sat_presolve()); + if (presolve_lrat_proof_handler != nullptr) { + model->Register(presolve_lrat_proof_handler.get()); + } + + if (!CopyModel(model_proto, new_cp_model_proto, model)) { const std::string info = "Problem proved infeasible during initial copy."; SOLVER_LOG(logger, info); - if (context->lrat_proof_handler != nullptr) { - context->lrat_proof_handler->Close(/*model_is_unsat=*/true); + if (model->Mutable() != nullptr) { + model->Mutable()->Close(/*model_is_unsat=*/true); } CpSolverResponse status_response; status_response.set_status(CpSolverStatus::INFEASIBLE); @@ -2734,23 +2765,33 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { " routes constraint(s)."); } - ClearInternalFields(context->working_model, logger); + ClearInternalFields(new_cp_model_proto, logger); if (absl::GetFlag(FLAGS_cp_model_ignore_objective) && - (context->working_model->has_objective() || - context->working_model->has_floating_point_objective())) { + (new_cp_model_proto->has_objective() || + new_cp_model_proto->has_floating_point_objective())) { SOLVER_LOG(logger, "Ignoring objective"); - context->working_model->clear_objective(); - context->working_model->clear_floating_point_objective(); + new_cp_model_proto->clear_objective(); + new_cp_model_proto->clear_floating_point_objective(); } if (absl::GetFlag(FLAGS_cp_model_ignore_hints) && - context->working_model->has_solution_hint()) { + new_cp_model_proto->has_solution_hint()) { SOLVER_LOG(logger, "Ignoring solution hint"); - context->working_model->clear_solution_hint(); + new_cp_model_proto->clear_solution_hint(); } + SOLVER_LOG(logger, + absl::StrFormat("Starting presolve at %.2fs", wall_timer->Get())); + auto context = std::make_unique(model, new_cp_model_proto, + mapping_proto); + // Checks for hints early in case they are forced to be hard constraints. + // + // Note that this still use the original user-given hint, which is not + // clamped. We also don't have hint for potential new variable created during + // copy/canonicalization. But we should be able to recover their value quite + // quickly when we fix the hint. if (params.fix_variables_to_their_hinted_value() && model_proto.has_solution_hint()) { FixVariablesToHintValue(model_proto.solution_hint(), context.get(), logger); @@ -2880,12 +2921,17 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { const CpSolverStatus presolve_status = PresolveCpModel(context.get(), &postsolve_mapping); - // Delete the context as soon as the presolve is done. Note that only - // postsolve_mapping and mapping_proto are needed for postsolve. - if (context->lrat_proof_handler != nullptr) { - context->lrat_proof_handler->Close(presolve_status == + // Delete the presolve_lrat_proof_handler. + // This is needed to properly write the first proof file. + if (presolve_lrat_proof_handler != nullptr) { + presolve_lrat_proof_handler->Close(presolve_status == CpSolverStatus::INFEASIBLE); + model->Unregister(); + presolve_lrat_proof_handler.reset(nullptr); } + + // Delete the context as soon as the presolve is done. Note that only + // postsolve_mapping and mapping_proto are needed for postsolve. context.reset(nullptr); if (presolve_status != CpSolverStatus::UNKNOWN) { @@ -2939,8 +2985,15 @@ CpSolverResponse SolveCpModel(const CpModelProto& model_proto, Model* model) { } else { TimeLimit time_limit; shared_time_limit->UpdateLocalLimit(&time_limit); - DetectAndAddSymmetryToProto(params, new_cp_model_proto, logger, - &time_limit); + DetectAndAddSymmetryToProto(params, *new_cp_model_proto, + new_cp_model_proto->mutable_symmetry(), + logger, &time_limit); + } + + // TODO(user): Some code just check new_cp_model_proto->has_symmetry(). + // If we don't have any generator, better to just clear the field. + if (new_cp_model_proto->symmetry().permutations().empty()) { + new_cp_model_proto->clear_symmetry(); } } diff --git a/ortools/sat/cp_model_solver_helpers.cc b/ortools/sat/cp_model_solver_helpers.cc index 69160ea26ab..b84b9e1b609 100644 --- a/ortools/sat/cp_model_solver_helpers.cc +++ b/ortools/sat/cp_model_solver_helpers.cc @@ -43,7 +43,7 @@ #include "ortools/base/macros/os_support.h" #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" -#include "ortools/base/version.h" +#include "ortools/base/version.h" // IWYU pragma: keep #include "ortools/graph_base/connected_components.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/clause.h" @@ -1334,8 +1334,32 @@ void LoadCpModel(const CpModelProto& model_proto, Model* model) { if (!sat_solver->FinishPropagation()) return unsat(); } - // Try to extract some structure before we start anything else. - model->GetOrCreate()->EarlyGateDetection(); + if (parameters.use_sat_inprocessing()) { + if (parameters.inprocessing_detect_and_sweep_circuit()) { + auto solve_cp_model_callback = [](const CpModelProto& cp_model_proto) { + Model model; + auto* params = model.GetOrCreate(); + params->set_log_search_progress(false); + params->set_log_to_stdout(false); + params->set_catch_sigint_signal(false); + params->set_linearization_level(0); + params->set_max_time_in_seconds(2); + params->set_cp_model_probing_level(0); + params->set_use_sat_inprocessing(false); + model.GetOrCreate()->ResetLimitFromParameters(*params); + auto* response_manager = model.GetOrCreate(); + response_manager->InitializeObjective(cp_model_proto); + LoadCpModel(cp_model_proto, &model); + SolveLoadedCpModel(cp_model_proto, &model); + return response_manager->GetResponse(); + }; + model->GetOrCreate()->SetSolveCallback( + solve_cp_model_callback); + } + + // Try to extract some structure before we start anything else. + model->GetOrCreate()->EarlyGateDetection(); + } // Note that this is already done in the presolve, but it is important to redo // it here to collect literal => integer >= bound constraints that are used in @@ -1756,9 +1780,75 @@ void QuickSolveWithHint(const CpModelProto& model_proto, Model* model) { // Solve decision problem. ConfigureSearchHeuristics(model); + SatSolver::Status status = SatSolver::Status::LIMIT_REACHED; + const auto& mapping = *model->GetOrCreate(); - const SatSolver::Status status = ResetAndSolveIntegerProblem( - mapping.Literals(model_proto.assumptions()), model); + const auto& integer_trail = *model->GetOrCreate(); + auto* encoder = model->GetOrCreate(); + + // On problems where the propagation is really slow, completing a valid but + // incomplete hint can take hours. This is because we propagate after each + // decision. This is especially true for scheduling or 2D packing. + // + // Here we try first to see if loading the hint as assumptions work since + // propagation will be a lot more efficient in this case. That requires + // creating all relevant literals beforehand though. + // + // Note(user): I am not sure always creating all such associated literal when + // we have a hint is good... especially in LNS subsolvers? that said it might + // orient the solution around the hint. And if the hint was full and valid, + // this is exactly what the HINT_SEARCH would have done. + // + // TODO(user): If we have user assumptions, we can still do that if they align + // with the hint. + // + // TODO(user): For enumerate_all_solutions() we can make this work, but + // currently ExcludeCurrentSolutionAndBacktrack() will not work with all + // assumptions taken at the same decision level. Also we will not have the + // nice analysis of the subset of decisions that are sufficient for exclusion. + if (parameters->try_hint_as_assumptions() && + !parameters->enumerate_all_solutions() && + model_proto.assumptions().empty() && + !shared_response_manager->HasFeasibleSolution()) { + std::vector assumptions; + for (int i = 0; i < model_proto.solution_hint().vars_size(); ++i) { + const int ref = model_proto.solution_hint().vars(i); + const IntegerValue value(model_proto.solution_hint().values(i)); + CHECK(RefIsPositive(ref)); + if (mapping.IsBoolean(ref)) { + assumptions.push_back(value == 1 ? mapping.Literal(ref) + : mapping.Literal(ref).Negated()); + } else { + const IntegerVariable var = mapping.Integer(ref); + const IntegerValue lb = integer_trail.LevelZeroLowerBound(var); + const IntegerValue ub = integer_trail.LevelZeroUpperBound(var); + if (lb == ub) continue; + if (value == lb) { + assumptions.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, value))); + } else if (value == ub) { + assumptions.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, value))); + } else { + // TODO(user): Use equality? + assumptions.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::LowerOrEqual(var, value))); + assumptions.push_back(encoder->GetOrCreateAssociatedLiteral( + IntegerLiteral::GreaterOrEqual(var, value))); + } + } + } + status = ResetAndSolveIntegerProblem(assumptions, model); + } + + // TODO(user): Still do that if we are in ASSUMPTION_UNSAT ? Now that the + // literal are created, normal search will still kind of follow the hint (but + // maybe not as systematically as this). We also do not really need to + // backtrack, we could resume from the state above with a bit of tweaking. + if (status != SatSolver::Status::FEASIBLE) { + status = ResetAndSolveIntegerProblem( + mapping.Literals(model_proto.assumptions()), model); + } const std::string& solution_info = model->Name(); if (status == SatSolver::Status::FEASIBLE) { diff --git a/ortools/sat/cp_model_symmetries.cc b/ortools/sat/cp_model_symmetries.cc index e38ff2b9833..13de4063131 100644 --- a/ortools/sat/cp_model_symmetries.cc +++ b/ortools/sat/cp_model_symmetries.cc @@ -40,6 +40,7 @@ #include "ortools/algorithms/sparse_permutation.h" #include "ortools/base/hash.h" #include "ortools/base/log_severity.h" +#include "ortools/base/types.h" #include "ortools/graph_base/graph.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" @@ -160,7 +161,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( CHECK(initial_equivalence_classes != nullptr); const int num_variables = problem.variables_size(); - auto graph = std::make_unique(); + typename Graph::Builder builder; // Each node will be created with a given color. Two nodes of different color // can never be send one into another by a symmetry. The first element of @@ -176,7 +177,8 @@ std::unique_ptr GenerateGraphForSymmetryDetection( }; IdGenerator color_id_generator; initial_equivalence_classes->clear(); - auto new_node_from_id = [&initial_equivalence_classes, &graph](int color_id) { + auto new_node_from_id = [&initial_equivalence_classes, + &builder](int color_id) { // Since we add nodes one by one, initial_equivalence_classes->size() gives // the number of nodes at any point, which we use as the next node index. const int node = initial_equivalence_classes->size(); @@ -184,7 +186,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // In some corner cases, we create a node but never uses it. We still // want it to be there. - graph->AddNode(node); + builder.AddNode(node); return node; }; auto new_node = [&new_node_from_id, @@ -223,7 +225,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // with a given coefficient. absl::flat_hash_map, int> coefficient_nodes; auto get_coefficient_node = - [&new_node_from_id, &graph, &coefficient_nodes, &color_id_generator, + [&new_node_from_id, &builder, &coefficient_nodes, &color_id_generator, &tmp_color, color_id_for_coeff_minus_one](int var, int64_t coeff) { const int var_node = var; DCHECK(RefIsPositive(var)); @@ -247,7 +249,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( color_id = color_id_generator.GetId(tmp_color); } const int secondary_node = new_node_from_id(color_id); - graph->AddArc(var_node, secondary_node); + builder.AddArc(var_node, secondary_node); insert.first->second = secondary_node; return secondary_node; }; @@ -272,7 +274,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // node. This makes sure that any permutation that touch a variable, must // permute its coefficient nodes accordingly. absl::flat_hash_set> implications; - auto get_implication_node = [&new_node_from_id, &graph, &coefficient_nodes, + auto get_implication_node = [&new_node_from_id, &builder, &coefficient_nodes, color_id_for_coeff_one, color_id_for_coeff_minus_one](int ref) { const int var = PositiveRef(ref); @@ -282,23 +284,23 @@ std::unique_ptr GenerateGraphForSymmetryDetection( if (!insert.second) return insert.first->second; const int secondary_node = new_node_from_id( coeff == 1 ? color_id_for_coeff_one : color_id_for_coeff_minus_one); - graph->AddArc(var, secondary_node); + builder.AddArc(var, secondary_node); insert.first->second = secondary_node; return secondary_node; }; - auto add_implication = [&get_implication_node, &graph, &implications]( + auto add_implication = [&get_implication_node, &builder, &implications]( int ref_a, int ref_b) { const auto insert = implications.insert({ref_a, ref_b}); if (!insert.second) return; - graph->AddArc(get_implication_node(ref_a), get_implication_node(ref_b)); + builder.AddArc(get_implication_node(ref_a), get_implication_node(ref_b)); // Always add the other side. implications.insert({NegatedRef(ref_b), NegatedRef(ref_a)}); - graph->AddArc(get_implication_node(NegatedRef(ref_b)), - get_implication_node(NegatedRef(ref_a))); + builder.AddArc(get_implication_node(NegatedRef(ref_b)), + get_implication_node(NegatedRef(ref_a))); }; - auto make_linear_expr_node = [&new_node, &graph, &get_coefficient_node]( + auto make_linear_expr_node = [&new_node, &builder, &get_coefficient_node]( const LinearExpressionProto& expr, const std::vector& color) { std::vector local_color = color; @@ -310,7 +312,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( const int var_node = PositiveRef(ref); const int64_t coeff = RefIsPositive(ref) ? expr.coeffs(i) : -expr.coeffs(i); - graph->AddArc(get_coefficient_node(var_node, coeff), local_node); + builder.AddArc(get_coefficient_node(var_node, coeff), local_node); } return local_node; }; @@ -357,8 +359,8 @@ std::unique_ptr GenerateGraphForSymmetryDetection( const int64_t coeff = RefIsPositive(ref) ? constraint.linear().coeffs(i) : -constraint.linear().coeffs(i); - graph->AddArc(get_coefficient_node(variable_node, coeff), - constraint_node); + builder.AddArc(get_coefficient_node(variable_node, coeff), + constraint_node); } break; } @@ -366,14 +368,14 @@ std::unique_ptr GenerateGraphForSymmetryDetection( CHECK_EQ(constraint_node, new_node(color)); for (const LinearExpressionProto& expr : constraint.all_diff().exprs()) { - graph->AddArc(shared_linear_expr_node(expr), constraint_node); + builder.AddArc(shared_linear_expr_node(expr), constraint_node); } break; } case ConstraintProto::kBoolOr: { CHECK_EQ(constraint_node, new_node(color)); for (const int ref : constraint.bool_or().literals()) { - graph->AddArc(get_literal_node(ref), constraint_node); + builder.AddArc(get_literal_node(ref), constraint_node); } break; } @@ -388,21 +390,21 @@ std::unique_ptr GenerateGraphForSymmetryDetection( CHECK_EQ(constraint_node, new_node(color)); for (const int ref : constraint.at_most_one().literals()) { - graph->AddArc(get_literal_node(ref), constraint_node); + builder.AddArc(get_literal_node(ref), constraint_node); } break; } case ConstraintProto::kExactlyOne: { CHECK_EQ(constraint_node, new_node(color)); for (const int ref : constraint.exactly_one().literals()) { - graph->AddArc(get_literal_node(ref), constraint_node); + builder.AddArc(get_literal_node(ref), constraint_node); } break; } case ConstraintProto::kBoolXor: { CHECK_EQ(constraint_node, new_node(color)); for (const int ref : constraint.bool_xor().literals()) { - graph->AddArc(get_literal_node(ref), constraint_node); + builder.AddArc(get_literal_node(ref), constraint_node); } break; } @@ -410,7 +412,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( if (constraint.enforcement_literal_size() > 1) { CHECK_EQ(constraint_node, new_node(color)); for (const int ref : constraint.bool_and().literals()) { - graph->AddArc(get_literal_node(ref), constraint_node); + builder.AddArc(get_literal_node(ref), constraint_node); } break; } @@ -431,7 +433,7 @@ std::unique_ptr GenerateGraphForSymmetryDetection( for (int i = 0; i < constraint.lin_max().exprs_size(); ++i) { const LinearExpressionProto& expr = constraint.lin_max().exprs(i); - graph->AddArc(shared_linear_expr_node(expr), target_node); + builder.AddArc(shared_linear_expr_node(expr), target_node); } break; @@ -471,8 +473,8 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // Make sure that if one node is mapped to another one, its other two // components are the same. - graph->AddArc(start_node, end_node); - graph->AddArc(end_node, size_node); + builder.AddArc(start_node, end_node); + builder.AddArc(end_node, size_node); } interval_constraint_index_to_node[constraint_index] = constraint_node; break; @@ -483,8 +485,8 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // here and in a bunch of other places. CHECK_EQ(constraint_node, new_node(color)); for (const int interval : constraint.no_overlap().intervals()) { - graph->AddArc(interval_constraint_index_to_node.at(interval), - constraint_node); + builder.AddArc(interval_constraint_index_to_node.at(interval), + constraint_node); } break; } @@ -499,18 +501,18 @@ std::unique_ptr GenerateGraphForSymmetryDetection( const int node_x = new_node(local_color); const int node_y = new_node(local_color); local_color.pop_back(); - graph->AddArc(constraint_node, node_x); - graph->AddArc(constraint_node, node_y); + builder.AddArc(constraint_node, node_x); + builder.AddArc(constraint_node, node_y); local_color.push_back(1); for (int i = 0; i < size; ++i) { const int box_node = new_node(local_color); - graph->AddArc(box_node, constraint_node); + builder.AddArc(box_node, constraint_node); const int x = constraint.no_overlap_2d().x_intervals(i); const int y = constraint.no_overlap_2d().y_intervals(i); - graph->AddArc(interval_constraint_index_to_node.at(x), node_x); - graph->AddArc(interval_constraint_index_to_node.at(x), box_node); - graph->AddArc(interval_constraint_index_to_node.at(y), node_y); - graph->AddArc(interval_constraint_index_to_node.at(y), box_node); + builder.AddArc(interval_constraint_index_to_node.at(x), node_x); + builder.AddArc(interval_constraint_index_to_node.at(x), box_node); + builder.AddArc(interval_constraint_index_to_node.at(y), node_y); + builder.AddArc(interval_constraint_index_to_node.at(y), box_node); } break; } @@ -522,17 +524,17 @@ std::unique_ptr GenerateGraphForSymmetryDetection( std::vector capacity_color = color; capacity_color.push_back(0); CHECK_EQ(constraint_node, new_node(capacity_color)); - graph->AddArc(constraint_node, - make_linear_expr_node(ct.capacity(), capacity_color)); + builder.AddArc(constraint_node, + make_linear_expr_node(ct.capacity(), capacity_color)); std::vector task_color = color; task_color.push_back(1); for (int i = 0; i < ct.intervals().size(); ++i) { const int task_node = make_linear_expr_node(ct.demands(i), task_color); - graph->AddArc(task_node, constraint_node); - graph->AddArc(task_node, - interval_constraint_index_to_node.at(ct.intervals(i))); + builder.AddArc(task_node, constraint_node); + builder.AddArc(task_node, + interval_constraint_index_to_node.at(ct.intervals(i))); } break; } @@ -564,11 +566,11 @@ std::unique_ptr GenerateGraphForSymmetryDetection( // To make the graph directed, we add two arcs on the head but not on // the tail. if (!constraint.enforcement_literal().empty()) { - graph->AddArc(constraint_node, arc_node); + builder.AddArc(constraint_node, arc_node); } - graph->AddArc(tail_node, arc_node); - graph->AddArc(arc_node, get_literal_node(literal)); - graph->AddArc(arc_node, head_node); + builder.AddArc(tail_node, arc_node); + builder.AddArc(arc_node, get_literal_node(literal)); + builder.AddArc(arc_node, head_node); } break; } @@ -595,12 +597,12 @@ std::unique_ptr GenerateGraphForSymmetryDetection( CHECK_LT(constraint_node, initial_equivalence_classes->size()); } for (const int ref : constraint.enforcement_literal()) { - graph->AddArc(constraint_node, get_literal_node(ref)); + builder.AddArc(constraint_node, get_literal_node(ref)); } } } - graph->Build(); + auto graph = std::move(builder).Build(nullptr); DCHECK_EQ(graph->num_nodes(), initial_equivalence_classes->size()); // TODO(user): The symmetry code does not officially support multi-arcs. And @@ -781,28 +783,25 @@ void LogOrbitInformation(absl::Span var_to_orbit_index, } // namespace void DetectAndAddSymmetryToProto(const SatParameters& params, - CpModelProto* proto, SolverLogger* logger, - TimeLimit* time_limit) { - SymmetryProto* symmetry = proto->mutable_symmetry(); - symmetry->Clear(); + const CpModelProto& proto, + SymmetryProto* mutable_symmetry, + SolverLogger* logger, TimeLimit* time_limit) { + mutable_symmetry->Clear(); std::vector> generators; - FindCpModelSymmetries(params, *proto, &generators, logger, time_limit); - if (generators.empty()) { - proto->clear_symmetry(); - return; - } + FindCpModelSymmetries(params, proto, &generators, logger, time_limit); + if (generators.empty()) return; // Log orbit information. // // TODO(user): It might be nice to just add this to the proto rather than // re-reading the generators and recomputing this in a few places. - const int num_vars = proto->variables().size(); + const int num_vars = proto.variables().size(); const std::vector orbits = GetOrbits(num_vars, generators); LogOrbitInformation(orbits, logger); for (const std::unique_ptr& perm : generators) { - SparsePermutationProto* perm_proto = symmetry->add_permutations(); + SparsePermutationProto* perm_proto = mutable_symmetry->add_permutations(); const int num_cycle = perm->NumCycles(); for (int i = 0; i < num_cycle; ++i) { const int old_size = perm_proto->support().size(); @@ -817,7 +816,7 @@ void DetectAndAddSymmetryToProto(const SatParameters& params, if (orbitope.empty()) return; SOLVER_LOG(logger, "[Symmetry] Found orbitope of size ", orbitope.size(), " x ", orbitope[0].size()); - DenseMatrixProto* matrix = symmetry->add_orbitopes(); + DenseMatrixProto* matrix = mutable_symmetry->add_orbitopes(); matrix->set_num_rows(orbitope.size()); matrix->set_num_cols(orbitope[0].size()); for (const std::vector& row : orbitope) { @@ -948,19 +947,22 @@ std::vector BuildInequalityCoeffsForOrbitope( bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const SatParameters& params = context->params(); - const CpModelProto& proto = *context->working_model; + const CpModelProto& proto = context->WorkingModel(); // We need to make sure the proto is up to date before computing symmetries! - if (context->working_model->has_objective()) { + if (proto.has_objective()) { context->WriteObjectiveToProto(); } context->WriteVariableDomainsToProto(); // Tricky: the equivalence relation are not part of the proto. - // We thus add them temporarily to compute the symmetry. - int64_t num_added = 0; - const int initial_ct_index = proto.constraints().size(); + // We thus add them temporarily to compute the symmetry, but we can't + // use the context functions while doing that though! + // + // TODO(user): we should be able to always remove all affine relation, so + // this complexity can probably be removed now. const int num_vars = proto.variables_size(); + std::vector> to_add; for (int var = 0; var < num_vars; ++var) { if (context->IsFixed(var)) continue; if (context->VariableWasRemoved(var)) continue; @@ -969,8 +971,16 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const AffineRelation::Relation r = context->GetAffineRelation(var); if (r.representative == var) continue; + to_add.push_back({var, r}); + } + + // Temporarily add affine relation still required. + int64_t num_added = 0; + const int initial_ct_index = proto.constraints().size(); + CpModelProto* mutable_model = context->UnsafeMutableWorkingModel(); + for (const auto [var, r] : to_add) { ++num_added; - ConstraintProto* ct = context->working_model->add_constraints(); + ConstraintProto* ct = mutable_model->add_constraints(); auto* arg = ct->mutable_linear(); arg->add_vars(var); arg->add_coeffs(1); @@ -980,14 +990,14 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { arg->add_domain(r.offset); } + // Find symmetries. std::vector> generators; FindCpModelSymmetries(params, proto, &generators, context->logger(), context->time_limit()); // Remove temporary affine relation. - context->working_model->mutable_constraints()->DeleteSubrange( - initial_ct_index, num_added); - + mutable_model->mutable_constraints()->DeleteSubrange(initial_ct_index, + num_added); if (generators.empty()) return true; // Collect the at most ones. @@ -1148,7 +1158,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { // orbitope symmetry on one line. But this line being the objective is key. We // can also explicitly look for a full permutation group of the objective // terms directly instead of finding the largest orbitope first. - if (!orbitope.empty() && context->working_model->has_objective()) { + if (!orbitope.empty() && context->WorkingModel().has_objective()) { const int num_objective_terms = context->ObjectiveMap().size(); if (orbitope[0].size() == num_objective_terms) { int num_in_column = 0; @@ -1157,19 +1167,17 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { } if (num_in_column == 1) { context->WriteObjectiveToProto(); - const auto& obj = context->working_model->objective(); + const auto& obj = context->WorkingModel().objective(); CHECK_EQ(num_objective_terms, obj.vars().size()); for (int i = 1; i < num_objective_terms; ++i) { - auto* new_ct = - context->working_model->add_constraints()->mutable_linear(); + auto* new_ct = context->AddConstraint()->mutable_linear(); new_ct->add_vars(obj.vars(i - 1)); new_ct->add_vars(obj.vars(i)); new_ct->add_coeffs(1); new_ct->add_coeffs(-1); new_ct->add_domain(0); - new_ct->add_domain(std::numeric_limits::max()); + new_ct->add_domain(kint64max); } - context->UpdateNewConstraintsVariableUsage(); context->UpdateRuleStats("symmetry: objective is one orbitope row."); return true; } @@ -1241,7 +1249,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { if (orbit_sizes[orbit_index] > num_in_orbit + 1) { context->UpdateRuleStats( "symmetry: added orbit symmetry breaking implications"); - auto* ct = context->working_model->add_constraints(); + auto* ct = context->AddConstraint(); auto* bool_and = ct->mutable_bool_and(); ct->add_enforcement_literal(NegatedRef(distinguished_var)); for (int var = 0; var < num_vars; ++var) { @@ -1250,7 +1258,6 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { if (context->IsFixed(var)) continue; bool_and->add_literals(NegatedRef(var)); } - context->UpdateNewConstraintsVariableUsage(); } return true; } @@ -1524,16 +1531,15 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { continue; } - ConstraintProto* ct = context->working_model->add_constraints(); + ConstraintProto* ct = context->AddConstraint(); ct->mutable_linear()->add_coeffs(1); ct->mutable_linear()->add_vars(orbitope[0][i]); ct->mutable_linear()->add_coeffs(-1); ct->mutable_linear()->add_vars(orbitope[0][i + 1]); ct->mutable_linear()->add_domain(0); - ct->mutable_linear()->add_domain(std::numeric_limits::max()); + ct->mutable_linear()->add_domain(kint64max); context->UpdateRuleStats("symmetry: added symmetry breaking inequality"); } - context->UpdateNewConstraintsVariableUsage(); } else if (orbitope.size() > 1) { std::vector max_values(orbitope.size()); for (int i = 0; i < orbitope.size(); ++i) { @@ -1547,7 +1553,7 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { const std::vector coeffs = BuildInequalityCoeffsForOrbitope( max_values, (int64_t{1} << kMaxBits), &is_approximated); for (int i = 0; i + 1 < orbitope[0].size(); ++i) { - ConstraintProto* ct = context->working_model->add_constraints(); + ConstraintProto* ct = context->AddConstraint(); auto* arg = ct->mutable_linear(); for (int j = 0; j < orbitope.size(); ++j) { const int64_t coeff = coeffs[j]; @@ -1561,8 +1567,8 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { context->MinOf(orbitope[j][i])); } arg->add_domain(0); - arg->add_domain(std::numeric_limits::max()); - DCHECK(!PossibleIntegerOverflow(*context->working_model, arg->vars(), + arg->add_domain(kint64max); + DCHECK(!PossibleIntegerOverflow(context->WorkingModel(), arg->vars(), arg->coeffs())); } context->UpdateRuleStats( @@ -1570,7 +1576,6 @@ bool DetectAndExploitSymmetriesInPresolve(PresolveContext* context) { is_approximated ? "approximated " : "", "inequality ordering orbitope columns"), orbitope[0].size()); - context->UpdateNewConstraintsVariableUsage(); return true; } @@ -1740,7 +1745,7 @@ bool FilterOrbitOnUnusedOrFixedVariables(SymmetryProto* symmetry, // Lets output the new statistics. // TODO(user): Avoid the reconvertion. { - const int num_vars = context->working_model->variables().size(); + const int num_vars = context->WorkingModel().variables().size(); std::vector> generators; for (const SparsePermutationProto& perm : symmetry->permutations()) { generators.emplace_back(CreateSparsePermutationFromProto(num_vars, perm)); diff --git a/ortools/sat/cp_model_symmetries.h b/ortools/sat/cp_model_symmetries.h index d87743f75f2..51ec1099676 100644 --- a/ortools/sat/cp_model_symmetries.h +++ b/ortools/sat/cp_model_symmetries.h @@ -53,7 +53,9 @@ void FindCpModelSymmetries( // Detects symmetries and fill the symmetry field. void DetectAndAddSymmetryToProto(const SatParameters& params, - CpModelProto* proto, SolverLogger* logger, + const CpModelProto& proto, + SymmetryProto* mutable_symmetry, + SolverLogger* logger, TimeLimit* solver_time_limit); // Basic implementation of some symmetry breaking during presolve. diff --git a/ortools/sat/cp_model_symmetries_test.cc b/ortools/sat/cp_model_symmetries_test.cc index 4eaed61aebc..3d411622bc1 100644 --- a/ortools/sat/cp_model_symmetries_test.cc +++ b/ortools/sat/cp_model_symmetries_test.cc @@ -360,7 +360,7 @@ TEST(FindCpModelSymmetries, ImplicationTestThatUsedToFail) { TEST(DetectAndAddSymmetryToProto, BasicTest) { // A model with one (0, 1) (2, 3) symmetry. - CpModelProto model = ParseTestProto(R"pb( + CpModelProto cp_model = ParseTestProto(R"pb( variables { domain: [ 0, 1 ] } variables { domain: [ 0, 1 ] } variables { domain: [ 0, 2 ] } @@ -391,7 +391,8 @@ TEST(DetectAndAddSymmetryToProto, BasicTest) { SatParameters params; params.set_log_search_progress(true); TimeLimit time_limit; - DetectAndAddSymmetryToProto(params, &model, &logger, &time_limit); + DetectAndAddSymmetryToProto(params, cp_model, cp_model.mutable_symmetry(), + &logger, &time_limit); // TODO(user): canonicalize the order in each cycle? const SymmetryProto expected = ParseTestProto(R"pb( @@ -401,7 +402,7 @@ TEST(DetectAndAddSymmetryToProto, BasicTest) { } )pb"); - EXPECT_THAT(model.symmetry(), testing::EqualsProto(expected)); + EXPECT_THAT(cp_model.symmetry(), testing::EqualsProto(expected)); } const char kBooleanModel[] = R"pb( @@ -663,7 +664,6 @@ TEST(FindCpModelSymmetries, BinPacking) { model.GetOrCreate()->SetLogToStdOut(true); PresolveContext context(&model, &proto, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); context.ReadObjectiveFromProto(); EXPECT_TRUE(DetectAndExploitSymmetriesInPresolve(&context)); context.LogInfo(); diff --git a/ortools/sat/cp_model_table.h b/ortools/sat/cp_model_table.h index 0a62e9aa40a..ebb1ae05773 100644 --- a/ortools/sat/cp_model_table.h +++ b/ortools/sat/cp_model_table.h @@ -21,6 +21,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/presolve_context.h" @@ -40,7 +41,7 @@ void CanonicalizeTable(PresolveContext* context, ConstraintProto* ct); // regexps. // // This method is exposed for testing purposes. -constexpr int64_t kTableAnyValue = std::numeric_limits::min(); +constexpr int64_t kTableAnyValue = kint64min; void CompressTuples(absl::Span domain_sizes, std::vector>* tuples); diff --git a/ortools/sat/cp_model_test.cc b/ortools/sat/cp_model_test.cc index 23e275c5c78..b5a25690fa7 100644 --- a/ortools/sat/cp_model_test.cc +++ b/ortools/sat/cp_model_test.cc @@ -28,6 +28,7 @@ #include "ortools/base/gmock.h" #include "ortools/base/log_severity.h" #include "ortools/base/parse_test_proto.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_solver.h" @@ -618,8 +619,7 @@ TEST(CpModelTest, TestGreaterOrEqual) { EXPECT_EQ(0, cp_model.Proto().constraints(0).linear().vars(0)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().coeffs(0)); EXPECT_EQ(10, cp_model.Proto().constraints(0).linear().domain(0)); - EXPECT_EQ(std::numeric_limits::max(), - cp_model.Proto().constraints(0).linear().domain(1)); + EXPECT_EQ(kint64max, cp_model.Proto().constraints(0).linear().domain(1)); } TEST(CpModelTest, TestGreater) { @@ -631,8 +631,7 @@ TEST(CpModelTest, TestGreater) { EXPECT_EQ(0, cp_model.Proto().constraints(0).linear().vars(0)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().coeffs(0)); EXPECT_EQ(11, cp_model.Proto().constraints(0).linear().domain(0)); - EXPECT_EQ(std::numeric_limits::max(), - cp_model.Proto().constraints(0).linear().domain(1)); + EXPECT_EQ(kint64max, cp_model.Proto().constraints(0).linear().domain(1)); } TEST(CpModelTest, TestLessOrEqual) { @@ -643,8 +642,7 @@ TEST(CpModelTest, TestLessOrEqual) { EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().vars_size()); EXPECT_EQ(0, cp_model.Proto().constraints(0).linear().vars(0)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().coeffs(0)); - EXPECT_EQ(std::numeric_limits::min(), - cp_model.Proto().constraints(0).linear().domain(0)); + EXPECT_EQ(kint64min, cp_model.Proto().constraints(0).linear().domain(0)); EXPECT_EQ(10, cp_model.Proto().constraints(0).linear().domain(1)); } @@ -656,8 +654,7 @@ TEST(CpModelTest, TestLess) { EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().vars_size()); EXPECT_EQ(0, cp_model.Proto().constraints(0).linear().vars(0)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().coeffs(0)); - EXPECT_EQ(std::numeric_limits::min(), - cp_model.Proto().constraints(0).linear().domain(0)); + EXPECT_EQ(kint64min, cp_model.Proto().constraints(0).linear().domain(0)); EXPECT_EQ(9, cp_model.Proto().constraints(0).linear().domain(1)); } @@ -688,12 +685,10 @@ TEST(CpModelTest, TestNotEqual) { EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().vars(1)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().coeffs(0)); EXPECT_EQ(-1, cp_model.Proto().constraints(0).linear().coeffs(1)); - EXPECT_EQ(std::numeric_limits::min(), - cp_model.Proto().constraints(0).linear().domain(0)); + EXPECT_EQ(kint64min, cp_model.Proto().constraints(0).linear().domain(0)); EXPECT_EQ(-1, cp_model.Proto().constraints(0).linear().domain(1)); EXPECT_EQ(1, cp_model.Proto().constraints(0).linear().domain(2)); - EXPECT_EQ(std::numeric_limits::max(), - cp_model.Proto().constraints(0).linear().domain(3)); + EXPECT_EQ(kint64max, cp_model.Proto().constraints(0).linear().domain(3)); } TEST(CpModelTest, TestAllDifferent) { diff --git a/ortools/sat/cp_model_test_utils.cc b/ortools/sat/cp_model_test_utils.cc index 4e2944aad2c..0e20b6c05f7 100644 --- a/ortools/sat/cp_model_test_utils.cc +++ b/ortools/sat/cp_model_test_utils.cc @@ -19,6 +19,7 @@ #include #include "absl/random/random.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" @@ -68,7 +69,7 @@ CpModelProto RandomLinearProblem(int num_variables, int num_constraints) { auto* ct = result.add_constraints()->mutable_linear(); const int min_value = num_variables / 10; ct->add_domain(min_value); - ct->add_domain(std::numeric_limits::max()); + ct->add_domain(kint64max); for (int v = 0; v < num_variables; ++v) { if (absl::Bernoulli(random, 0.5) || // To ensure that the constraint is feasible, we enforce that it has diff --git a/ortools/sat/cp_model_utils.cc b/ortools/sat/cp_model_utils.cc index dc2282e191b..2eb5234fbe1 100644 --- a/ortools/sat/cp_model_utils.cc +++ b/ortools/sat/cp_model_utils.cc @@ -189,8 +189,18 @@ void GetReferencesUsedByConstraint(const ConstraintProto& ct, // The node expressions are not used by the constraint itself. break; case ConstraintProto::ConstraintCase::kInverse: - AddIndices(ct.inverse().f_direct(), variables); - AddIndices(ct.inverse().f_inverse(), variables); + if (!ct.inverse().f_direct().empty()) { + AddIndices(ct.inverse().f_direct(), variables); + AddIndices(ct.inverse().f_inverse(), variables); + } else { + for (const LinearExpressionProto& expr : ct.inverse().f_expr_direct()) { + AddIndices(expr.vars(), variables); + } + for (const LinearExpressionProto& expr : + ct.inverse().f_expr_inverse()) { + AddIndices(expr.vars(), variables); + } + } break; case ConstraintProto::ConstraintCase::kReservoir: for (const LinearExpressionProto& time : ct.reservoir().time_exprs()) { @@ -384,8 +394,17 @@ void ApplyToAllVariableIndices(absl::FunctionRef f, case ConstraintProto::ConstraintCase::kRoutes: break; case ConstraintProto::ConstraintCase::kInverse: - APPLY_TO_REPEATED_FIELD(inverse, f_direct); - APPLY_TO_REPEATED_FIELD(inverse, f_inverse); + if (!ct->inverse().f_direct().empty()) { + APPLY_TO_REPEATED_FIELD(inverse, f_direct); + APPLY_TO_REPEATED_FIELD(inverse, f_inverse); + } else { + for (int i = 0; i < ct->inverse().f_expr_direct_size(); ++i) { + APPLY_TO_REPEATED_FIELD(inverse, f_expr_direct(i)->mutable_vars); + } + for (int i = 0; i < ct->inverse().f_expr_inverse_size(); ++i) { + APPLY_TO_REPEATED_FIELD(inverse, f_expr_inverse(i)->mutable_vars); + } + } break; case ConstraintProto::ConstraintCase::kReservoir: for (int i = 0; i < ct->reservoir().time_exprs_size(); ++i) { @@ -612,6 +631,7 @@ std::vector UsedIntervals(const ConstraintProto& ct) { AddIndices(ct.no_overlap().intervals(), &used_intervals); break; case ConstraintProto::ConstraintCase::kNoOverlap2D: + used_intervals.reserve(2 * ct.no_overlap_2d().x_intervals_size()); AddIndices(ct.no_overlap_2d().x_intervals(), &used_intervals); AddIndices(ct.no_overlap_2d().y_intervals(), &used_intervals); break; @@ -831,8 +851,19 @@ uint64_t FingerprintModel(const CpModelProto& model, uint64_t seed) { fp = FingerprintRepeatedField(ct.routes().literals(), fp); break; case ConstraintProto::ConstraintCase::kInverse: - fp = FingerprintRepeatedField(ct.inverse().f_direct(), fp); - fp = FingerprintRepeatedField(ct.inverse().f_inverse(), fp); + if (!ct.inverse().f_direct().empty()) { + fp = FingerprintRepeatedField(ct.inverse().f_direct(), fp); + fp = FingerprintRepeatedField(ct.inverse().f_inverse(), fp); + } else { + for (const LinearExpressionProto& expr : + ct.inverse().f_expr_direct()) { + fp = FingerprintExpression(expr, fp); + } + for (const LinearExpressionProto& expr : + ct.inverse().f_expr_inverse()) { + fp = FingerprintExpression(expr, fp); + } + } break; case ConstraintProto::ConstraintCase::kReservoir: fp = FingerprintSingleField(ct.reservoir().min_level(), fp); diff --git a/ortools/sat/cp_model_utils.h b/ortools/sat/cp_model_utils.h index f7b589c234d..0775ac18b01 100644 --- a/ortools/sat/cp_model_utils.h +++ b/ortools/sat/cp_model_utils.h @@ -33,6 +33,7 @@ #include "ortools/base/hash.h" #include "ortools/base/macros/os_support.h" #include "ortools/base/options.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/util/bitset.h" #include "ortools/util/sorted_interval_list.h" @@ -195,10 +196,8 @@ std::vector AllValuesInDomain(const ProtoWithDomain& proto) { inline double ScaleObjectiveValue(const CpObjectiveProto& proto, int64_t value) { double result = static_cast(value); - if (value == std::numeric_limits::min()) - result = -std::numeric_limits::infinity(); - if (value == std::numeric_limits::max()) - result = std::numeric_limits::infinity(); + if (value == kint64min) result = -std::numeric_limits::infinity(); + if (value == kint64max) result = std::numeric_limits::infinity(); result += proto.offset(); if (proto.scaling_factor() == 0) return result; return proto.scaling_factor() * result; @@ -362,6 +361,17 @@ void SetupTextFormatPrinter(google::protobuf::TextFormat::Printer* printer); static_assert(!kTargetOsSupportsProtoDescriptor); #endif // ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR +#if defined(ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR) +template +std::string PrettyPrintModelProto([[maybe_unused]] const M& proto) { + std::string proto_string; + google::protobuf::TextFormat::Printer printer; + SetupTextFormatPrinter(&printer); + printer.PrintToString(proto, &proto_string); + return proto_string; +} +#endif // ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR + template bool WriteModelProtoToFile([[maybe_unused]] const M& proto, [[maybe_unused]] absl::string_view filename) { @@ -369,11 +379,9 @@ bool WriteModelProtoToFile([[maybe_unused]] const M& proto, static_assert(kTargetOsSupportsProtoDescriptor); if (absl::EndsWith(filename, "txt") || absl::EndsWith(filename, "textproto")) { - std::string proto_string; - google::protobuf::TextFormat::Printer printer; - SetupTextFormatPrinter(&printer); - printer.PrintToString(proto, &proto_string); - return file::SetContents(filename, proto_string, file::Defaults()).ok(); + return file::SetContents(filename, PrettyPrintModelProto(proto), + file::Defaults()) + .ok(); } else { return file::SetBinaryProto(filename, proto, file::Defaults()).ok(); } diff --git a/ortools/sat/csharp/CMakeLists.txt b/ortools/sat/csharp/CMakeLists.txt index 3c53e5d1b37..e15df8c3349 100644 --- a/ortools/sat/csharp/CMakeLists.txt +++ b/ortools/sat/csharp/CMakeLists.txt @@ -14,7 +14,7 @@ set_property(SOURCE sat.i PROPERTY CPLUSPLUS ON) set_property(SOURCE sat.i PROPERTY SWIG_MODULE_NAME operations_research_sat) set_property(SOURCE sat.i PROPERTY COMPILE_DEFINITIONS - ${OR_TOOLS_COMPILE_DEFINITIONS}) + ${OR_TOOLS_COMPILE_DEFINITIONS} ABSL_MUST_USE_RESULT=) set_property(SOURCE sat.i PROPERTY COMPILE_OPTIONS -namespace ${DOTNET_PROJECT}.Sat -dllimport google-ortools-native) diff --git a/ortools/sat/cumulative_energy.cc b/ortools/sat/cumulative_energy.cc index f3e6845156e..b101c905ee1 100644 --- a/ortools/sat/cumulative_energy.cc +++ b/ortools/sat/cumulative_energy.cc @@ -405,9 +405,9 @@ void CumulativeIsAfterSubsetConstraint::RegisterWith( CumulativeDualFeasibleEnergyConstraint::CumulativeDualFeasibleEnergyConstraint( AffineExpression capacity, SchedulingConstraintHelper* helper, SchedulingDemandHelper* demands, Model* model) - : random_(model->GetOrCreate()), + : random_(*model->GetOrCreate()), shared_stats_(model->GetOrCreate()), - opp_infeasibility_detector_(*random_, shared_stats_), + opp_infeasibility_detector_(random_, shared_stats_), capacity_(capacity), integer_trail_(model->GetOrCreate()), helper_(helper), @@ -656,7 +656,7 @@ bool CumulativeDualFeasibleEnergyConstraint::Propagate() { absl::InlinedVector, 3> sampled_candidates; std::sample(candidates_for_conflict.begin(), candidates_for_conflict.end(), - std::back_inserter(sampled_candidates), 3, *random_); + std::back_inserter(sampled_candidates), 3, random_); for (const auto& [window_start, window_end] : sampled_candidates) { if (!FindAndPropagateConflict(window_start, window_end)) { ++num_conflicts_; diff --git a/ortools/sat/cumulative_energy.h b/ortools/sat/cumulative_energy.h index 14d80ac90b7..3f52174d46c 100644 --- a/ortools/sat/cumulative_energy.h +++ b/ortools/sat/cumulative_energy.h @@ -136,7 +136,7 @@ class CumulativeDualFeasibleEnergyConstraint : public PropagatorInterface { bool FindAndPropagateConflict(IntegerValue window_start, IntegerValue window_end); - ModelRandomGenerator* random_; + absl::BitGenRef random_; SharedStatistics* shared_stats_; OrthogonalPackingInfeasibilityDetector opp_infeasibility_detector_; const AffineExpression capacity_; diff --git a/ortools/sat/cuts.cc b/ortools/sat/cuts.cc index 9b037ea689b..07297eef9db 100644 --- a/ortools/sat/cuts.cc +++ b/ortools/sat/cuts.cc @@ -39,6 +39,7 @@ #include "absl/types/span.h" #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/lp_data/lp_types.h" #include "ortools/sat/clause.h" #include "ortools/sat/implied_bounds.h" @@ -352,8 +353,7 @@ ABSL_DEPRECATED("Only used in tests, this will be removed.") bool CutDataBuilder::ConvertToLinearConstraint(const CutData& cut, LinearConstraint* output) { tmp_map_.clear(); - if (cut.rhs > absl::int128(std::numeric_limits::max()) || - cut.rhs < absl::int128(std::numeric_limits::min())) { + if (cut.rhs > absl::int128(kint64max) || cut.rhs < absl::int128(kint64min)) { return false; } IntegerValue new_rhs = static_cast(cut.rhs); @@ -696,7 +696,7 @@ IntegerValue GetFactorT(IntegerValue rhs_remainder, IntegerValue divisor, // Make sure that when we multiply the rhs or the coefficient by a factor t, // we do not have an integer overflow. Note that the rhs should be counted // in max_magnitude since we will apply f() on it. - IntegerValue max_t(std::numeric_limits::max()); + IntegerValue max_t(kint64max); if (max_magnitude != 0) { max_t = max_t / max_magnitude; } @@ -718,8 +718,7 @@ std::function GetSuperAdditiveRoundingFunction( // Make sure we don't have an integer overflow below. Note that we assume that // divisor and the maximum coeff magnitude are not too different (maybe a // factor 1000 at most) so that the final result will never overflow. - max_scaling = - std::min(max_scaling, std::numeric_limits::max() / divisor); + max_scaling = std::min(max_scaling, kint64max / divisor); const IntegerValue size = divisor - rhs_remainder; if (max_scaling == 1 || size == 1) { @@ -837,9 +836,7 @@ GetSuperAdditiveStrengtheningMirFunction(IntegerValue positive_rhs, } // We need to scale. - scaling = - std::min(scaling, IntegerValue(std::numeric_limits::max()) / - positive_rhs); + scaling = std::min(scaling, IntegerValue(kint64max) / positive_rhs); if (scaling == 1) { return [](IntegerValue v) { if (v >= 0) return IntegerValue(0); @@ -1680,7 +1677,7 @@ bool CoverCutHelper::TrySingleNodeFlow(const CutData& input_ct, // strengthening just result in all coeff at 1, so worse than our cover // heuristic. CHECK_LT(cut_.rhs, 0); - if (cut_.rhs <= absl::int128(std::numeric_limits::min())) { + if (cut_.rhs <= absl::int128(kint64min)) { return false; } @@ -1789,7 +1786,7 @@ bool CoverCutHelper::TryWithLetchfordSouliLifting( // We don't support big rhs here. // Note however than since this only deal with Booleans, it is less likely. - if (cut_.rhs > absl::int128(std::numeric_limits::max())) { + if (cut_.rhs > absl::int128(kint64max)) { ++ls_stats_.num_overflow_aborts; return false; } @@ -1832,7 +1829,7 @@ bool CoverCutHelper::TryWithLetchfordSouliLifting( std::vector thresholds; for (int i = 0; i < q; ++i) { // TODO(user): compute this in an overflow-safe way. - if (CapProd(p.value(), i + 1) >= std::numeric_limits::max() - 1) { + if (CapProd(p.value(), i + 1) >= kint64max - 1) { ++ls_stats_.num_overflow_aborts; return false; } @@ -3012,9 +3009,8 @@ CutGenerator CreateCliqueCutGenerator( // We need to express such "at most one" in term of the initial // variables, so we do not use the // LinearConstraintBuilder::AddLiteralTerm() here. - LinearConstraintBuilder builder( - model, IntegerValue(std::numeric_limits::min()), - IntegerValue(1)); + LinearConstraintBuilder builder(model, IntegerValue(kint64min), + IntegerValue(1)); for (const Literal l : at_most_one) { if (positive_map.contains(l.Index())) { builder.AddTerm(positive_map.at(l.Index()), IntegerValue(1)); diff --git a/ortools/sat/diffn.cc b/ortools/sat/diffn.cc index 07aa6279125..6c502469a44 100644 --- a/ortools/sat/diffn.cc +++ b/ortools/sat/diffn.cc @@ -32,6 +32,7 @@ #include "absl/log/vlog_is_on.h" #include "absl/numeric/bits.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/2d_distances_propagator.h" #include "ortools/sat/2d_mandatory_overlap_propagator.h" #include "ortools/sat/2d_orthogonal_packing.h" @@ -498,7 +499,7 @@ std::optional NonOverlappingRectanglesEnergyPropagator::FindConflict( std::vector active_box_ranges) { const auto rectangles_with_too_much_energy = - FindRectanglesWithEnergyConflictMC(active_box_ranges, *random_, 1.0, 0.8); + FindRectanglesWithEnergyConflictMC(active_box_ranges, random_, 1.0, 0.8); if (rectangles_with_too_much_energy.conflicts.empty() && rectangles_with_too_much_energy.candidates.empty()) { @@ -515,11 +516,11 @@ NonOverlappingRectanglesEnergyPropagator::FindConflict( absl::InlinedVector sampled_rectangles; std::sample(rectangles_with_too_much_energy.conflicts.begin(), rectangles_with_too_much_energy.conflicts.end(), - std::back_inserter(sampled_rectangles), 5, *random_); + std::back_inserter(sampled_rectangles), 5, random_); std::sample(rectangles_with_too_much_energy.candidates.begin(), rectangles_with_too_much_energy.candidates.end(), std::back_inserter(sampled_rectangles), - kSampleSize - sampled_rectangles.size(), *random_); + kSampleSize - sampled_rectangles.size(), random_); std::sort(sampled_rectangles.begin(), sampled_rectangles.end(), [](const Rectangle& a, const Rectangle& b) { const bool larger = std::make_pair(a.SizeX(), a.SizeY()) > @@ -928,8 +929,8 @@ bool NonOverlappingRectanglesDisjunctivePropagator:: if (!x_.ResetFromSubset(*x, boxes)) return false; // Collect the common overlapping coordinates of all boxes. - IntegerValue lb(std::numeric_limits::min()); - IntegerValue ub(std::numeric_limits::max()); + IntegerValue lb(kint64min); + IntegerValue ub(kint64max); for (const int b : boxes) { lb = std::max(lb, y->StartMax(b)); ub = std::min(ub, y->EndMin(b) - 1); diff --git a/ortools/sat/diffn.h b/ortools/sat/diffn.h index bb8d72ca269..7b52f06010b 100644 --- a/ortools/sat/diffn.h +++ b/ortools/sat/diffn.h @@ -44,9 +44,9 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { NonOverlappingRectanglesEnergyPropagator(NoOverlap2DConstraintHelper* helper, Model* model) : helper_(*helper), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), shared_stats_(model->GetOrCreate()), - orthogonal_packing_checker_(*random_, shared_stats_) {} + orthogonal_packing_checker_(random_, shared_stats_) {} ~NonOverlappingRectanglesEnergyPropagator() override; @@ -68,7 +68,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface { bool BuildAndReportEnergyTooLarge(absl::Span ranges); NoOverlap2DConstraintHelper& helper_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; SharedStatistics* shared_stats_; OrthogonalPackingInfeasibilityDetector orthogonal_packing_checker_; diff --git a/ortools/sat/diffn_util.cc b/ortools/sat/diffn_util.cc index 85a9a11add3..9a1cc7c1748 100644 --- a/ortools/sat/diffn_util.cc +++ b/ortools/sat/diffn_util.cc @@ -1596,15 +1596,185 @@ std::string RenderDot(std::optional bb, return ss.str(); } -std::vector FindEmptySpaces( - const Rectangle& bounding_box, std::vector ocupied_rectangles) { +std::vector FindEmptySpaces(const Rectangle& bounding_box, + std::vector rectangles) { // Sorting is not necessary for correctness but makes it faster. - std::sort(ocupied_rectangles.begin(), ocupied_rectangles.end(), + std::sort(rectangles.begin(), rectangles.end(), [](const Rectangle& a, const Rectangle& b) { return std::tuple(a.x_min, -a.x_max, a.y_min) < std::tuple(b.x_min, -b.x_max, b.y_min); }); - return PavedRegionDifference({bounding_box}, ocupied_rectangles); + return PavedRegionDifference({bounding_box}, rectangles); +} + +// We abstract the X and Y axes into "sweep" (the direction the line moves) +// and "cross" (the direction the active intervals span). +template +struct SweepTraits; + +// Horizontal: sweep line moves up Y-axis. Intervals span X-axis. +template <> +struct SweepTraits { + static IntegerValue sweep_min(const Rectangle& r) { return r.y_min; } + static IntegerValue sweep_max(const Rectangle& r) { return r.y_max; } + static IntegerValue cross_min(const Rectangle& r) { return r.x_min; } + static IntegerValue cross_max(const Rectangle& r) { return r.x_max; } + static Rectangle make_rect(IntegerValue c_min, IntegerValue c_max, + IntegerValue s_min, IntegerValue s_max) { + return {c_min, c_max, s_min, s_max}; + } +}; + +// Vertical: sweep line moves right X-axis. Intervals span Y-axis. +template <> +struct SweepTraits { + static IntegerValue sweep_min(const Rectangle& r) { return r.x_min; } + static IntegerValue sweep_max(const Rectangle& r) { return r.x_max; } + static IntegerValue cross_min(const Rectangle& r) { return r.y_min; } + static IntegerValue cross_max(const Rectangle& r) { return r.y_max; } + static Rectangle make_rect(IntegerValue c_min, IntegerValue c_max, + IntegerValue s_min, IntegerValue s_max) { + return {s_min, s_max, c_min, c_max}; + } +}; + +namespace { + +template +std::vector FindEmptySpacesImpl( + const Rectangle& bounding_box, absl::Span rectangles) { + using Traits = SweepTraits; + + struct Event { + IntegerValue sweep_coord; + IntegerValue cross_min; + IntegerValue cross_max; + int type; // 0 for TOP/RIGHT event, 1 for BOTTOM/LEFT event. + int rect_idx; + + bool operator<(const Event& other) const { + if (sweep_coord != other.sweep_coord) + return sweep_coord < other.sweep_coord; + if (type != other.type) return type < other.type; + // Geometric Tie-Breaker for adjacent rectangles. + if (cross_min != other.cross_min) return cross_min < other.cross_min; + return cross_max < other.cross_max; + } + }; + + struct EmptyInterval { + IntegerValue cross_min; + IntegerValue cross_max; + IntegerValue sweep_start; + int left_idx; + int right_idx; + + bool operator<(const EmptyInterval& other) const { + return cross_min < other.cross_min; + } + }; + + std::vector events; + events.reserve(rectangles.size() * 2); + + for (int i = 0; i < rectangles.size(); ++i) { + const auto& rect = rectangles[i]; + + const IntegerValue c_min = Traits::cross_min(rect); + const IntegerValue c_max = Traits::cross_max(rect); + const IntegerValue s_min = Traits::sweep_min(rect); + const IntegerValue s_max = Traits::sweep_max(rect); + + DCHECK_LT(c_min, c_max); + DCHECK_LT(s_min, s_max); + DCHECK_GE(c_min, Traits::cross_min(bounding_box)); + DCHECK_LE(c_max, Traits::cross_max(bounding_box)); + DCHECK_GE(s_min, Traits::sweep_min(bounding_box)); + DCHECK_LE(s_max, Traits::sweep_max(bounding_box)); + + events.push_back({s_min, c_min, c_max, 1, i}); + events.push_back({s_max, c_min, c_max, 0, i}); + } + + absl::c_sort(events); + + absl::btree_set active_intervals; + active_intervals.insert({Traits::cross_min(bounding_box), + Traits::cross_max(bounding_box), + Traits::sweep_min(bounding_box), -1, -1}); + + std::vector results; + + auto add_result = [&](IntegerValue c_min, IntegerValue c_max, + IntegerValue s_min, IntegerValue s_max, int l_idx, + int r_idx) { + if (c_min < c_max && s_min < s_max) { + results.push_back( + {Traits::make_rect(c_min, c_max, s_min, s_max), l_idx, r_idx}); + } + }; + + for (const Event& ev : events) { + if (ev.type == 0) { + const auto right_it = + active_intervals.lower_bound({ev.cross_max, 0, 0, 0, 0}); + + DCHECK(right_it != active_intervals.end()); + DCHECK(right_it != active_intervals.begin()); + + const auto left_it = std::prev(right_it); + + const EmptyInterval left_inv = *left_it; + const EmptyInterval right_inv = *right_it; + + active_intervals.erase(left_it, std::next(right_it)); + + add_result(left_inv.cross_min, left_inv.cross_max, left_inv.sweep_start, + ev.sweep_coord, left_inv.left_idx, left_inv.right_idx); + add_result(right_inv.cross_min, right_inv.cross_max, + right_inv.sweep_start, ev.sweep_coord, right_inv.left_idx, + right_inv.right_idx); + + active_intervals.insert({left_inv.cross_min, right_inv.cross_max, + ev.sweep_coord, left_inv.left_idx, + right_inv.right_idx}); + + } else { + auto it = active_intervals.upper_bound({ev.cross_min, 0, 0, 0, 0}); + if (it == active_intervals.begin()) return {}; + --it; + + const EmptyInterval inv = *it; + active_intervals.erase(it); + + add_result(inv.cross_min, inv.cross_max, inv.sweep_start, ev.sweep_coord, + inv.left_idx, inv.right_idx); + + active_intervals.insert({inv.cross_min, ev.cross_min, ev.sweep_coord, + inv.left_idx, ev.rect_idx}); + active_intervals.insert({ev.cross_max, inv.cross_max, ev.sweep_coord, + ev.rect_idx, inv.right_idx}); + } + } + + for (const EmptyInterval& inv : active_intervals) { + add_result(inv.cross_min, inv.cross_max, inv.sweep_start, + Traits::sweep_max(bounding_box), inv.left_idx, inv.right_idx); + } + + return results; +} + +} // namespace + +std::vector FindEmptySpacesHorizontally( + const Rectangle& bounding_box, absl::Span rectangles) { + return FindEmptySpacesImpl(bounding_box, rectangles); +} + +std::vector FindEmptySpacesVertically( + const Rectangle& bounding_box, absl::Span rectangles) { + return FindEmptySpacesImpl(bounding_box, rectangles); } std::vector PavedRegionDifference( @@ -1632,8 +1802,11 @@ std::vector PavedRegionDifference( // don't delete nodes that became stale, as explained in the class comment // below. struct BinaryTreeNode { - // Contains exactly one element if occupying_box_index != -1. - absl::flat_hash_set connected_components_descendants; + static constexpr int kEmpty = -1; // No box is currently occupying this node. + static constexpr int kMixed = -2; // More than one component is currently + // occupying this node. + int component_if_single = kEmpty; + // Hold the x_max of the box that is currently occupying this node (if any) to // know when it is stale. int occupying_box_x_max; @@ -1664,21 +1837,31 @@ struct SweepLineIntervalTree { // {self} + left.connected_components + right.connected_components. void RecomputeConnectedComponents(TreeNodeIndex idx) { BinaryTreeNode& node = tree_nodes[idx]; - if (node.occupying_box_index != -1) { - node.connected_components_descendants = { - union_find.FindRoot(node.occupying_box_index)}; + if (node.occupying_box_index != BinaryTreeNode::kEmpty) { + node.component_if_single = union_find.FindRoot(node.occupying_box_index); return; } - node.connected_components_descendants.clear(); - if (tree.IsLeaf(idx)) return; - for (const TreeNodeIndex child_idx : - {tree.LeftChild(idx), tree.RightChild(idx)}) { - // The order is non-deterministic, but since this is doing the union of - // hash sets the result is deterministic. - for (const int c : - tree_nodes[child_idx].connected_components_descendants) { - node.connected_components_descendants.insert(union_find.FindRoot(c)); - } + + if (tree.IsLeaf(idx)) { + node.component_if_single = BinaryTreeNode::kEmpty; + return; + } + + int left_comp = tree_nodes[tree.LeftChild(idx)].component_if_single; + int right_comp = tree_nodes[tree.RightChild(idx)].component_if_single; + + // Resolve roots in case they were merged elsewhere + if (left_comp >= 0) left_comp = union_find.FindRoot(left_comp); + if (right_comp >= 0) right_comp = union_find.FindRoot(right_comp); + + if (left_comp == BinaryTreeNode::kEmpty) { + node.component_if_single = right_comp; + } else if (right_comp == BinaryTreeNode::kEmpty) { + node.component_if_single = left_comp; + } else if (left_comp == right_comp) { + node.component_if_single = left_comp; // Both share the same component + } else { + node.component_if_single = BinaryTreeNode::kMixed; } } @@ -1716,14 +1899,14 @@ struct SweepLineIntervalTree { // branch to the new box. continue; } - const bool had_different_component = - absl::c_any_of(child_node.connected_components_descendants, - [this, component_index](const int c) { - return !union_find.Connected(c, component_index); - }); + // Since everything is intersecting the current box, all descendants // must be in one single component. - child_node.connected_components_descendants = {component_index}; + const int previous_component = child_node.component_if_single; + child_node.component_if_single = component_index; + if (previous_component == BinaryTreeNode::kEmpty) { + continue; // No need to recurse. + } // Only go down on the tree if we have below either: // - a different component to connect. @@ -1733,7 +1916,8 @@ struct SweepLineIntervalTree { // delete. Since a box can only be deleted log N times (one per interval // it was cut into) and we can only connect O(N) components in total, the // amortized cost of a call to UpdateChildrenIntersecting is O((log N)^2). - if (had_different_component) { + if (previous_component == BinaryTreeNode::kMixed || + !union_find.Connected(previous_component, component_index)) { UpdateChildrenIntersecting(child_idx, sweep_line_x_pos, component_index, new_connections); } @@ -1743,18 +1927,42 @@ struct SweepLineIntervalTree { bool UpdateParents(TreeNodeIndex node, int sweep_line_x_pos, int component_index, std::vector* new_connections) { if (node == tree.Root()) return false; + + const int current_box_root = union_find.FindRoot(component_index); + for (TreeNodeIndex parent = tree.Parent(node); parent != tree.Root(); parent = tree.Parent(parent)) { RemoveNodeIfXMaxLowerOrEqual(parent, sweep_line_x_pos); BinaryTreeNode& parent_value = tree_nodes[parent]; if (parent_value.occupying_box_index != -1) { if (union_find.AddEdge(parent_value.occupying_box_index, - component_index)) { + current_box_root)) { new_connections->push_back(parent_value.occupying_box_index); - return true; + return true; // Short-circuit: connected to an ancestor, covers + // everything above + } + } + + // Update parent_value.component_if_single. + if (parent_value.component_if_single == BinaryTreeNode::kEmpty) { + // Node was empty, now holds our component + parent_value.component_if_single = current_box_root; + } else if (parent_value.component_if_single >= 0) { + // Node was homogenous. Are they STILL in the same component? + // We must check FindRoot() because they might have been merged + // elsewhere. + const int existing_root = + union_find.FindRoot(parent_value.component_if_single); + + if (existing_root != current_box_root) { + // Distinct active components exist in this subtree. It is now mixed. + parent_value.component_if_single = BinaryTreeNode::kMixed; + } else { + // They belong to the same component, keep it updated to the root + parent_value.component_if_single = existing_root; } } - parent_value.connected_components_descendants.insert(component_index); + // If it is already kMixed, it stays kMixed. No action needed. } return false; } @@ -1768,13 +1976,15 @@ struct SweepLineIntervalTree { RemoveNodeIfXMaxLowerOrEqual(idx, sweep_line_x_pos); int cur_box_component = union_find.FindRoot(box_index); BinaryTreeNode& node = tree_nodes[idx]; + if (node.occupying_box_index == -1) { - node.connected_components_descendants = {box_index}; + node.component_if_single = cur_box_component; node.occupying_box_index = box_index; node.occupying_box_x_max = x_max; + const bool had_occupied_parent = UpdateParents( idx, sweep_line_x_pos, cur_box_component, new_connections); - // We can only be connecting children if it is not already connect via + // We can only be connecting children if it is not already connected via // something above on the tree. if (!had_occupied_parent) { UpdateChildrenIntersecting(idx, sweep_line_x_pos, cur_box_component, @@ -1786,9 +1996,12 @@ struct SweepLineIntervalTree { new_connections->push_back(node.occupying_box_index); cur_box_component = union_find.FindRoot(cur_box_component); } - node.connected_components_descendants = {cur_box_component}; + + // The entire subtree is merged into this single component. + node.component_if_single = cur_box_component; + if (node.occupying_box_x_max < x_max) { - // Replace the existing box by the new one. + // Replace the existing box by the new one, since it lasts longer. node.occupying_box_index = box_index; node.occupying_box_x_max = x_max; } @@ -1806,6 +2019,11 @@ struct Rectangle32 { int y_min; int y_max; int index; + + bool IsInsideOf(const Rectangle32& other) const { + return x_min >= other.x_min && x_max <= other.x_max && + y_min >= other.y_min && y_max <= other.y_max; + } }; // Requires that rectangles are sorted by x_min and that sizes on both @@ -1856,56 +2074,63 @@ struct PostProcessedResult { std::pair bounding_box; // Always starting at (0,0). }; +// This function is a preprocessing function for algorithms that find overlap +// between rectangles. It does the following: +// - It converts the arbitrary int64_t coordinates into a small integer by +// sorting the possible values and assigning them consecutive integers. +// - It grows zero size intervals to make them size one. This simplifies +// things considerably, since it is hard to reason about degenerated +// rectangles in the general algorithm. +// +// Note that the last point need to be done with care. Imagine the following +// example: +// +----------+ +// | | +// | +--------------+ +// | | | +// | | p,q r | +// | +----*-----*-+-+ +// | | | +// | | | +// | | | +// | +------------+ +// | | +// | | +// +----------+ +// Where p,q and r are points (ie, boxes of size 0x0) and p and q have the +// same coordinates. We replace them by the following: +// +----------+ +// | | +// | +----------------------+ +// | | | +// | | | +// | +----+-+---------------+ +// | | |p| +// | | +-+-+ +// | | |q| +// | | +-+ +-+ +// | | |r| +// | +--------------+-+---+ +// | | | +// | | | +// | | | +// | +--------------------+ +// | | +// | | +// +----------+ +// +// That is a pretty radical deformation of the original shape, but it retains +// the property of whether a pair of rectangles intersect or not. +// +// If this function detects a pair (r1, r2) so that r1 is fully inside r2, +// it adds the pair to the `fully_contained_pairs` vector and omits r1 from +// the returned vector. Note that this is a best-effort heuristic and will not +// detect all fully contained pairs. Moreover, it will return each rectangle +// that is fully contained by another rectangle only once, even if it is +// contained by multiple rectangles. PostProcessedResult ConvertToRectangle32WithNonZeroSizes( - absl::Span rectangles) { - // This function is a preprocessing function for algorithms that find overlap - // between rectangles. It does the following: - // - It converts the arbitrary int64_t coordinates into a small integer by - // sorting the possible values and assigning them consecutive integers. - // - It grows zero size intervals to make them size one. This simplifies - // things considerably, since it is hard to reason about degenerated - // rectangles in the general algorithm. - // - // Note that the last point need to be done with care. Imagine the following - // example: - // +----------+ - // | | - // | +--------------+ - // | | | - // | | p,q r | - // | +----*-----*-+-+ - // | | | - // | | | - // | | | - // | +------------+ - // | | - // | | - // +----------+ - // Where p,q and r are points (ie, boxes of size 0x0) and p and q have the - // same coordinates. We replace them by the following: - // +----------+ - // | | - // | +----------------------+ - // | | | - // | | | - // | +----+-+---------------+ - // | | |p| - // | | +-+-+ - // | | |q| - // | | +-+ +-+ - // | | |r| - // | +--------------+-+---+ - // | | | - // | | | - // | | | - // | +--------------------+ - // | | - // | | - // +----------+ - // - // That is a pretty radical deformation of the original shape, but it retains - // the property of whether a pair of rectangles intersect or not. - + absl::Span rectangles, + std::vector>* fully_contained_pairs) { if (rectangles.empty()) return {}; enum class Event { @@ -1997,17 +2222,64 @@ PostProcessedResult ConvertToRectangle32WithNonZeroSizes( std::vector sorted_rectangles32; sorted_rectangles32.reserve(rectangles.size()); + + Rectangle32 largest_box; + for (const auto [x, event, index] : x_events) { if (event == Event::kBegin || event == Event::kPoint) { - sorted_rectangles32.push_back(rectangles32[index]); + const Rectangle32& rectangle = rectangles32[index]; + + if (sorted_rectangles32.empty()) { + // First element initialization + largest_box = rectangle; + } else { + const Rectangle32& prev = sorted_rectangles32.back(); + + // Check if the current box is fully inside the previous box or the + // largest box seen so far. If so, add it to the output and skip it + // completely. + if (rectangle.IsInsideOf(prev)) { + fully_contained_pairs->push_back({rectangle.index, prev.index}); + continue; + } + + if (largest_box.index != prev.index && + rectangle.IsInsideOf(largest_box)) { + fully_contained_pairs->push_back( + {rectangle.index, largest_box.index}); + continue; + } + + // Update the largest box. + const int current_size = std::min(rectangle.y_max - rectangle.y_min, + rectangle.x_max - rectangle.x_min); + const int remaining_x = largest_box.x_max - rectangle.x_min; + + int largest_remaining_size = -1; + if (remaining_x > 0) { + largest_remaining_size = + std::min(largest_box.y_max - largest_box.y_min, remaining_x); + } + + if (current_size > largest_remaining_size) { + largest_box = rectangle; + } + } + + sorted_rectangles32.push_back(rectangle); } } return {sorted_rectangles32, {max_x_index, max_y_index}}; } + +// If we happen to prove as a side-effect that some rectangles are disjoint from +// all the other rectangles, we add their indices to the +// `proven_disjoint_indices` vector. template std::optional> FindOneIntersectionIfPresentImpl( - absl::Span rectangles) { + absl::Span rectangles, + std::vector* proven_disjoint_indices = nullptr) { using CoordinateType = std::decay_t; DCHECK(absl::c_is_sorted(rectangles, [](const RectangleT& a, const RectangleT& b) { @@ -2041,6 +2313,9 @@ std::optional> FindOneIntersectionIfPresentImpl( if (!inserted) { if (rectangles[it->index].x_max <= x) { // We just replace if the rectangle at position i is stale. + if (proven_disjoint_indices) { + proven_disjoint_indices->push_back(it->index); + } it->index = i; } else { // Intersection. @@ -2055,6 +2330,9 @@ std::optional> FindOneIntersectionIfPresentImpl( // Lazy erase stale entry. if (rectangles[it_before->index].x_max <= x) { + if (proven_disjoint_indices) { + proven_disjoint_indices->push_back(it_before->index); + } // For absl::btree_set we don't have iterator stability, so we do need // to re-assign 'it' to the element just after the one we erased. it = interval_set.erase(it_before); @@ -2076,6 +2354,9 @@ std::optional> FindOneIntersectionIfPresentImpl( while (it != interval_set.end()) { // Lazy erase stale entry. if (rectangles[it->index].x_max <= x) { + if (proven_disjoint_indices) { + proven_disjoint_indices->push_back(it->index); + } it = interval_set.erase(it); continue; } @@ -2096,10 +2377,45 @@ std::optional> FindOneIntersectionIfPresentImpl( std::vector> FindPartialRectangleIntersections( absl::Span rectangles) { - auto postprocessed = ConvertToRectangle32WithNonZeroSizes(rectangles); - return FindPartialRectangleIntersectionsImpl( - postprocessed.rectangles_sorted_by_x_min, - postprocessed.bounding_box.second); + std::vector> extra_arcs; + auto postprocessed = + ConvertToRectangle32WithNonZeroSizes(rectangles, &extra_arcs); + + if (extra_arcs.empty()) { + std::vector proven_disjoint_indices; + // We didn't find any trivial intersection, let's check quickly if there + // is any intersection at all. + if (!FindOneIntersectionIfPresentImpl( + absl::MakeConstSpan(postprocessed.rectangles_sorted_by_x_min), + &proven_disjoint_indices) + .has_value()) { + return {}; + } + if (!proven_disjoint_indices.empty()) { + // Remove the proven disjoint rectangles from downstream processing. + std::vector is_disjoint( + postprocessed.rectangles_sorted_by_x_min.size(), false); + for (const int index : proven_disjoint_indices) { + is_disjoint[index] = true; + } + int new_rectangles_size = 0; + for (int i = 0; i < postprocessed.rectangles_sorted_by_x_min.size(); + ++i) { + if (!is_disjoint[i]) { + postprocessed.rectangles_sorted_by_x_min[new_rectangles_size++] = + postprocessed.rectangles_sorted_by_x_min[i]; + } + } + postprocessed.rectangles_sorted_by_x_min.resize(new_rectangles_size); + } + } + std::vector> intersections = + FindPartialRectangleIntersectionsImpl( + postprocessed.rectangles_sorted_by_x_min, + postprocessed.bounding_box.second); + intersections.insert(intersections.end(), extra_arcs.begin(), + extra_arcs.end()); + return intersections; } std::optional> FindOneIntersectionIfPresent( @@ -2109,7 +2425,12 @@ std::optional> FindOneIntersectionIfPresent( std::optional> FindOneIntersectionIfPresentWithZeroArea( absl::Span rectangles) { - auto postprocessed = ConvertToRectangle32WithNonZeroSizes(rectangles); + std::vector> extra_arcs; + auto postprocessed = + ConvertToRectangle32WithNonZeroSizes(rectangles, &extra_arcs); + if (!extra_arcs.empty()) { + return extra_arcs[0]; + } std::optional> result = FindOneIntersectionIfPresentImpl( absl::MakeConstSpan(postprocessed.rectangles_sorted_by_x_min)); if (!result.has_value()) return {}; diff --git a/ortools/sat/diffn_util.h b/ortools/sat/diffn_util.h index 110f8f59fd5..f4af1d82f7c 100644 --- a/ortools/sat/diffn_util.h +++ b/ortools/sat/diffn_util.h @@ -685,9 +685,58 @@ std::string RenderDot(std::optional bb, // Given a bounding box and a list of rectangles inside that bounding box, // returns a list of rectangles partitioning the empty area inside the bounding -// box. -std::vector FindEmptySpaces( - const Rectangle& bounding_box, std::vector ocupied_rectangles); +// box. Contrary to FindEmptySpacesHorizontally() and +// FindEmptySpacesVertically() below, this function supports overlapping input +// rectangles, but runs in O(N^2) time. +std::vector FindEmptySpaces(const Rectangle& bounding_box, + std::vector rectangles); + +// Given a bounding box and a list of non-overlapping rectangles inside that +// bounding box, returns a list of non-overlapping rectangles partitioning the +// empty area inside the bounding box. Moreover each rectangle has both its left +// and right edge fully touching either one single box or the bounding box. Note +// that the input rectangles must have non-zero area. +// +// Example: +// Input +// +------------------------------+ +// | | +// | | +// | 111111111 | +// | 111111111 | +// | | +// | 0000000 | +// | 0000000 | +// | | +// +------------------------------+ +// +// Output +// +------------------------------+ +// |******************************| +// |******************************| +// |...................111111111$$| +// |...................111111111$$| +// |""""""""""""""""""""""""""""""| +// |@@0000000*********************| +// |@@0000000*********************| +// |..............................| +// +------------------------------+ +struct EmptySpace { + Rectangle rect; + // Index of the rectangle touching the left/bottom edge, or -1 for bounding + // box. + int before_idx; + // Index of the rectangle touching the right/top edge, or -1 for bounding + // box. + int after_idx; +}; +std::vector FindEmptySpacesHorizontally( + const Rectangle& bounding_box, absl::Span rectangles); + +// Same as FindEmptySpacesHorizontally(), but returns vertical strips +// partitioning the empty space instead of horizontal ones. +std::vector FindEmptySpacesVertically( + const Rectangle& bounding_box, absl::Span rectangles); // Given two regions, each one of them defined by a vector of non-overlapping // rectangles paving them, returns a vector of non-overlapping rectangles that diff --git a/ortools/sat/diffn_util_test.cc b/ortools/sat/diffn_util_test.cc index 6d7cde88466..8da1dd61fa8 100644 --- a/ortools/sat/diffn_util_test.cc +++ b/ortools/sat/diffn_util_test.cc @@ -36,6 +36,7 @@ #include "benchmark/benchmark.h" #include "gtest/gtest.h" #include "ortools/base/gmock.h" +#include "ortools/base/types.h" #include "ortools/graph_base/connected_components.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/sat/2d_orthogonal_packing_testing.h" @@ -267,7 +268,7 @@ std::vector GenerateRandomIntervalVector( std::vector intervals; intervals.reserve(num_intervals); const int64_t interval_domain = - absl::LogUniform(random, 1, std::numeric_limits::max()); + absl::LogUniform(random, 1, kint64max); const int64_t max_interval_length = absl::Uniform( random, std::max(1, interval_domain / (2 * num_intervals + 1)), interval_domain); @@ -1104,6 +1105,95 @@ TEST(FindPartialIntersections, Random) { } } +TEST(FindEmptySpaces, Random) { + absl::BitGen random; + constexpr int num_runs = 100; + + const Rectangle bounding_box = { + .x_min = 0, .x_max = 100, .y_min = 0, .y_max = 100}; + const IntegerValue total_bb_area = bounding_box.Area(); + + for (int k = 0; k < num_runs; k++) { + std::vector rectangles = + GenerateNonConflictingRectanglesWithPacking({100, 100}, 60, random); + + IntegerValue occupied_area = 0; + for (const Rectangle& r : rectangles) { + occupied_area += r.Area(); + } + + for (const bool test_vertical : {false, true}) { + std::vector empty_spaces = + test_vertical ? FindEmptySpacesVertically(bounding_box, rectangles) + : FindEmptySpacesHorizontally(bounding_box, rectangles); + + std::vector all_rects = rectangles; + IntegerValue empty_area = 0; + + for (const EmptySpace& space : empty_spaces) { + all_rects.push_back(space.rect); + empty_area += space.rect.Area(); + + // Check that the returned rectangle is touching the adjacent inputs + if (test_vertical) { + if (space.before_idx == -1) { + EXPECT_EQ(space.rect.y_min, bounding_box.y_min); + } else { + ASSERT_GE(space.before_idx, 0); + ASSERT_LT(space.before_idx, rectangles.size()); + EXPECT_EQ(space.rect.y_min, rectangles[space.before_idx].y_max); + EXPECT_GE(space.rect.x_min, rectangles[space.before_idx].x_min); + EXPECT_LE(space.rect.x_max, rectangles[space.before_idx].x_max); + } + + if (space.after_idx == -1) { + EXPECT_EQ(space.rect.y_max, bounding_box.y_max); + } else { + ASSERT_GE(space.after_idx, 0); + ASSERT_LT(space.after_idx, rectangles.size()); + EXPECT_EQ(space.rect.y_max, rectangles[space.after_idx].y_min); + EXPECT_GE(space.rect.x_min, rectangles[space.after_idx].x_min); + EXPECT_LE(space.rect.x_max, rectangles[space.after_idx].x_max); + } + } else { + if (space.before_idx == -1) { + EXPECT_EQ(space.rect.x_min, bounding_box.x_min); + } else { + ASSERT_GE(space.before_idx, 0); + ASSERT_LT(space.before_idx, rectangles.size()); + EXPECT_EQ(space.rect.x_min, rectangles[space.before_idx].x_max); + EXPECT_GE(space.rect.y_min, rectangles[space.before_idx].y_min); + EXPECT_LE(space.rect.y_max, rectangles[space.before_idx].y_max); + } + + if (space.after_idx == -1) { + EXPECT_EQ(space.rect.x_max, bounding_box.x_max); + } else { + ASSERT_GE(space.after_idx, 0); + ASSERT_LT(space.after_idx, rectangles.size()); + EXPECT_EQ(space.rect.x_max, rectangles[space.after_idx].x_min); + EXPECT_GE(space.rect.y_min, rectangles[space.after_idx].y_min); + EXPECT_LE(space.rect.y_max, rectangles[space.after_idx].y_max); + } + } + } + + // Check that the result is is not overlapping neither avec itself nor + // with the input. + absl::c_sort(all_rects, [](const Rectangle& a, const Rectangle& b) { + return a.x_min < b.x_min; + }); + EXPECT_FALSE(FindOneIntersectionIfPresent(all_rects).has_value()) + << "Overlap detected between empty spaces and/or occupied " + "rectangles!"; + + // Check that the all bounding box area is covered. + EXPECT_EQ(occupied_area + empty_area, total_bb_area) + << "Sum of areas does not match bounding box."; + } + } +} + void BM_FindRectangles(benchmark::State& state) { absl::BitGen random; std::vector> problems; diff --git a/ortools/sat/diophantine.cc b/ortools/sat/diophantine.cc index dcc86ff39f9..1cff6cf2196 100644 --- a/ortools/sat/diophantine.cc +++ b/ortools/sat/diophantine.cc @@ -24,6 +24,7 @@ #include "absl/numeric/int128.h" #include "absl/types/span.h" #include "ortools/base/mathutil.h" +#include "ortools/base/types.h" #include "ortools/sat/util.h" namespace operations_research::sat { @@ -31,10 +32,10 @@ namespace operations_research::sat { namespace { int64_t Gcd(const absl::Span coeffs) { - DCHECK(coeffs[0] != std::numeric_limits::min()); + DCHECK(coeffs[0] != kint64min); int64_t gcd = std::abs(coeffs[0]); for (int i = 1; i < coeffs.size(); ++i) { - DCHECK(coeffs[i] != std::numeric_limits::min()); + DCHECK(coeffs[i] != kint64min); const int64_t abs_coeff = std::abs(coeffs[i]); gcd = std::gcd(gcd, abs_coeff); } @@ -69,12 +70,12 @@ void ReduceModuloBasis(absl::Span> basis, std::vector GreedyFastDecreasingGcd( const absl::Span coeffs) { std::vector result; - DCHECK(coeffs[0] != std::numeric_limits::min()); + DCHECK(coeffs[0] != kint64min); int64_t min_abs_coeff = std::abs(coeffs[0]); int min_term = 0; int64_t global_gcd = min_abs_coeff; for (int i = 1; i < coeffs.size(); ++i) { - DCHECK(coeffs[i] != std::numeric_limits::min()); + DCHECK(coeffs[i] != kint64min); const int64_t abs_coeff = std::abs(coeffs[i]); global_gcd = std::gcd(global_gcd, abs_coeff); if (abs_coeff < min_abs_coeff) { diff --git a/ortools/sat/diophantine_test.cc b/ortools/sat/diophantine_test.cc index d198aceb657..b09ae2e46a6 100644 --- a/ortools/sat/diophantine_test.cc +++ b/ortools/sat/diophantine_test.cc @@ -29,6 +29,7 @@ #include "gtest/gtest.h" #include "ortools/base/log_severity.h" #include "ortools/base/mathutil.h" +#include "ortools/base/types.h" namespace operations_research::sat { @@ -105,8 +106,7 @@ TEST_P(RandomTest, GreedyFastDecreasingGcd) { const int num_elements = absl::Uniform(random, 1, 50); std::vector coeffs(num_elements); for (int i = 0; i < num_elements; ++i) { - coeffs[i] = UniformNonZero(random, 1 + std::numeric_limits::min(), - std::numeric_limits::max()); + coeffs[i] = UniformNonZero(random, 1 + kint64min, kint64max); } const std::vector order = GreedyFastDecreasingGcd(coeffs); if (order.empty()) { diff --git a/ortools/sat/disjunctive.cc b/ortools/sat/disjunctive.cc index 4ce86b3980c..3b2815f24e7 100644 --- a/ortools/sat/disjunctive.cc +++ b/ortools/sat/disjunctive.cc @@ -18,7 +18,6 @@ #include #include -#include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -338,7 +337,9 @@ template void CombinedDisjunctive::AddNoOverlap( absl::Span vars) { const int index = task_sets_.size(); - task_sets_.emplace_back(task_set_storage_.emplace_back()); + auto& storage = task_set_storage_.emplace_back(); + storage.ClearAndReserve(vars.size()); + task_sets_.emplace_back(storage); end_mins_.push_back(kMinIntegerValue); for (const IntervalVariable var : vars) { task_to_disjunctives_[var.value()].push_back(index); @@ -466,6 +467,7 @@ bool CombinedDisjunctive::Propagate() { } bool DisjunctiveOverloadChecker::Propagate() { + if (helper_->FixedSearchFirstSolutionMode()) return true; if (!helper_->IsEnforced()) return true; stats_.OnPropagate(); if (!helper_->SynchronizeAndSetTimeDirection(/*is_forward=*/true)) { @@ -928,6 +930,7 @@ bool DisjunctiveSimplePrecedences::PropagateOneDirection() { } bool DisjunctiveDetectablePrecedences::Propagate() { + if (helper_->FixedSearchFirstSolutionMode()) return true; if (!helper_->IsEnforced()) return true; stats_.OnPropagate(); if (!helper_->SynchronizeAndSetTimeDirection(time_direction_)) { @@ -1211,6 +1214,7 @@ int DisjunctiveDetectablePrecedences::RegisterWith( } bool DisjunctivePrecedences::Propagate() { + if (helper_->FixedSearchFirstSolutionMode()) return true; if (!helper_->IsEnforced()) return true; stats_.OnPropagate(); if (!helper_->SynchronizeAndSetTimeDirection(time_direction_)) { @@ -1425,6 +1429,7 @@ int DisjunctivePrecedences::RegisterWith(GenericLiteralWatcher* watcher) { } bool DisjunctiveNotLast::Propagate() { + if (helper_->FixedSearchFirstSolutionMode()) return true; if (!helper_->IsEnforced()) return true; stats_.OnPropagate(); if (!helper_->SynchronizeAndSetTimeDirection(time_direction_)) { @@ -1643,6 +1648,7 @@ int DisjunctiveNotLast::RegisterWith(GenericLiteralWatcher* watcher) { } bool DisjunctiveEdgeFinding::Propagate() { + if (helper_->FixedSearchFirstSolutionMode()) return true; if (!helper_->IsEnforced()) return true; stats_.OnPropagate(); const int num_tasks = helper_->NumTasks(); diff --git a/ortools/sat/disjunctive.h b/ortools/sat/disjunctive.h index eaf588e4d8e..5495bfb3187 100644 --- a/ortools/sat/disjunctive.h +++ b/ortools/sat/disjunctive.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/scheduling.h" -#include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" namespace operations_research { @@ -273,7 +273,9 @@ class CombinedDisjunctive : public PropagatorInterface { std::vector> task_to_disjunctives_; std::vector task_is_added_; std::vector task_sets_; - std::vector> task_set_storage_; + + // Using a deque to avoid invalidating references. + std::deque> task_set_storage_; std::vector end_mins_; }; diff --git a/ortools/sat/docs/integer_arithmetic.md b/ortools/sat/docs/integer_arithmetic.md index 221cac23f17..4fac6d79368 100644 --- a/ortools/sat/docs/integer_arithmetic.md +++ b/ortools/sat/docs/integer_arithmetic.md @@ -851,14 +851,14 @@ def step_function_sample_sat(): # expr == 0 on [5, 6] U [8, 10] b0 = model.new_bool_var('b0') model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(5, 6), (8, 10)]) + x, cp_model.Domain.from_intervals([[5, 6], [8, 10]]) ).only_enforce_if(b0) model.add(expr == 0).only_enforce_if(b0) # expr == 2 on [0, 1] U [3, 4] U [11, 20] b2 = model.new_bool_var('b2') model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(0, 1), (3, 4), (11, 20)]) + x, cp_model.Domain.from_intervals([[0, 1], [3, 4], [11, 20]]) ).only_enforce_if(b2) model.add(expr == 2).only_enforce_if(b2) diff --git a/ortools/sat/docs/scheduling.md b/ortools/sat/docs/scheduling.md index 60ca9279738..ad1c7a1b0b9 100644 --- a/ortools/sat/docs/scheduling.md +++ b/ortools/sat/docs/scheduling.md @@ -2310,7 +2310,7 @@ def scheduling_with_calendar_sample_sat(): # Because of the break, work cannot start at 13h. start = model.new_int_var_from_domain( - cp_model.Domain.from_intervals([(8, 12), (14, 15)]), 'start' + cp_model.Domain.from_intervals([[8, 12], [14, 15]]), 'start' ) duration = model.new_int_var(3, 4, 'duration') end = model.new_int_var(8, 18, 'end') diff --git a/ortools/sat/drat_checker.cc b/ortools/sat/drat_checker.cc index 174e32ed997..e15c56d67de 100644 --- a/ortools/sat/drat_checker.cc +++ b/ortools/sat/drat_checker.cc @@ -19,7 +19,6 @@ #include #include #include // NOLINT -#include #include #include @@ -32,6 +31,7 @@ #include "absl/types/span.h" #include "ortools/base/hash.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/sat_base.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -135,7 +135,7 @@ void DratChecker::DeleteClause(absl::Span clause) { Clause& existing_clause = clauses_[*it]; existing_clause.num_copies -= 1; if (existing_clause.num_copies == 0) { - DCHECK(existing_clause.deleted_index == std::numeric_limits::max()); + DCHECK(existing_clause.deleted_index == kint32max); existing_clause.deleted_index = clauses_.size() - 1; if (clauses_.back().num_literals >= 2) { clauses_[ClauseIndex(clauses_.size() - 2)].deleted_clauses.push_back( @@ -278,7 +278,7 @@ void DratChecker::Init() { if (clause.num_literals >= 2) { // Don't watch the literals of the deleted clauses right away, instead // watch them when these clauses become 'undeleted' in backward checking. - if (clause.deleted_index == std::numeric_limits::max()) { + if (clause.deleted_index == kint32max) { WatchClause(clause_index); } } else if (clause.num_literals == 1) { diff --git a/ortools/sat/drat_checker.h b/ortools/sat/drat_checker.h index d666e79639f..a88c6c230a7 100644 --- a/ortools/sat/drat_checker.h +++ b/ortools/sat/drat_checker.h @@ -25,6 +25,7 @@ #include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/sat_base.h" #include "ortools/util/strong_integers.h" @@ -122,7 +123,7 @@ class DratChecker { // DeleteClause(c0), AddProblemClause(c2), ... if c0's index is 0, then its // deleted_index is 2. Meaning that when checking a clause whose index is // larger than or equal to 2 (e.g. c2), c0 can be ignored. - ClauseIndex deleted_index = ClauseIndex(std::numeric_limits::max()); + ClauseIndex deleted_index = ClauseIndex(kint32max); // The indices of the clauses (with at least two literals) which are deleted // just after this clause. diff --git a/ortools/sat/encoding.cc b/ortools/sat/encoding.cc index 0a3133f5ecb..8b7af8e2185 100644 --- a/ortools/sat/encoding.cc +++ b/ortools/sat/encoding.cc @@ -29,6 +29,7 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/pb_constraint.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -196,7 +197,7 @@ Coefficient EncodingNode::Reduce(const SatSolver& solver) { void EncodingNode::ApplyWeightUpperBound(Coefficient gap, SatSolver* solver) { CHECK_GT(weight_, 0); const Coefficient num_allowed = (gap / weight_); - if (num_allowed > std::numeric_limits::max() / 2) return; + if (num_allowed > kint32max / 2) return; const int new_size = std::max(0, (weight_lb_ - lb_) + static_cast(num_allowed.value())); if (size() <= new_size) return; @@ -703,7 +704,7 @@ bool ObjectiveEncoder::ProcessCore(absl::Span core, gtl::STLSortAndRemoveDuplicates(&adj); } const std::vector> index_decompo = - AtMostOneDecomposition(graph, *random_, &buffer); + AtMostOneDecomposition(graph, random_, &buffer); // Convert. std::vector new_order; diff --git a/ortools/sat/encoding.h b/ortools/sat/encoding.h index bbd3fecf535..fe3cd0b2f0b 100644 --- a/ortools/sat/encoding.h +++ b/ortools/sat/encoding.h @@ -233,7 +233,7 @@ class ObjectiveEncoder { : params_(*model->GetOrCreate()), sat_solver_(model->GetOrCreate()), implications_(model->GetOrCreate()), - random_(model->GetOrCreate()) {} + random_(*model->GetOrCreate()) {} // Updates the encoding using the given core. The literals in the core must // match the order in nodes. Returns false if the model become infeasible. @@ -271,7 +271,7 @@ class ObjectiveEncoder { const SatParameters& params_; SatSolver* sat_solver_; BinaryImplicationGraph* implications_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; }; } // namespace sat diff --git a/ortools/sat/feasibility_jump.cc b/ortools/sat/feasibility_jump.cc index 990f482ff1d..f180edce9c5 100644 --- a/ortools/sat/feasibility_jump.cc +++ b/ortools/sat/feasibility_jump.cc @@ -27,6 +27,7 @@ #include #include +#include "absl/flags/flag.h" #include "absl/functional/any_invocable.h" #include "absl/functional/bind_front.h" #include "absl/functional/function_ref.h" @@ -38,11 +39,14 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/algorithms/binary_search.h" +#include "ortools/base/log_severity.h" +#include "ortools/base/types.h" #include "ortools/sat/combine_solutions.h" #include "ortools/sat/constraint_violation.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_copy.h" +#include "ortools/sat/cp_model_utils.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/linear_model.h" #include "ortools/sat/sat_parameters.pb.h" @@ -402,10 +406,10 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { reset_weights = true; if (state_->options.use_restart) { states_->CollectStatistics(*state_); - state_->options.Randomize(params_, &random_); + state_->options.Randomize(params_, random_); state_->counters = LsCounters(); // Reset. } else { - state_->options.Randomize(params_, &random_); + state_->options.Randomize(params_, random_); } if (type() == SubSolver::INCOMPLETE) { // This is not used once we have a solution, and setting it to false @@ -549,6 +553,32 @@ std::function FeasibilityJumpSolver::GenerateTask(int64_t /*task_id*/) { } else { shared_response_->LogMessage(name(), "infeasible solution. Aborting."); model_is_supported_ = false; + if (DEBUG_MODE) { + CpSolverResponse response; + response.set_solution_info( + absl::StrCat(name(), "_", state_->options.name())); + evaluator_->ComputeAllNonLinearViolations(state_->solution); + for (int c = 0; c < evaluator_->NumEvaluatorConstraints(); ++c) { + if (!evaluator_->IsViolated(c)) continue; + LOG(INFO) << "Constraint " << c << " is violated with weight " + << state_->weights[c] << " and compound weight " + << state_->compound_weights[c] << "."; + LOG(INFO) << "Constraint " << c << " is " + << evaluator_->ConstraintDebugString(c); + } + response.mutable_solution()->Assign(state_->solution.begin(), + state_->solution.end()); + const std::string file = + absl::StrCat(absl::GetFlag(FLAGS_cp_model_dump_prefix), + "wrong_response.pb.txt"); + LOG(INFO) << "Dumping infeasible response proto to '" << file << "'."; + CHECK(WriteModelProtoToFile(response, file)); + + // Crash. + LOG(FATAL) << "Infeasible LS solution!" + << " source: '" << response.solution_info() << "'" + << " dumped CpSolverResponse to '" << file << "'."; + } } } @@ -581,7 +611,7 @@ double FeasibilityJumpSolver::ComputeScore(absl::Span weights, ++state_->counters.num_scores_computed; double score = evaluator_->WeightedViolationDelta( linear_only, weights, var, delta, absl::MakeSpan(state_->solution)); - constexpr double kEpsilon = 1.0 / std::numeric_limits::max(); + constexpr double kEpsilon = 1.0 / kint64max; score += kEpsilon * delta * evaluator_->ObjectiveCoefficient(var); return score; } @@ -621,8 +651,8 @@ std::pair FeasibilityJumpSolver::ComputeLinearJump(int var) { // Point p1 is improving. Look for best before it. // Note that we can exclude all point after current_value since it is // worse and we assume convexity. - const Domain dom = var_domains[var].IntersectionWith( - Domain(std::numeric_limits::min(), p1 - 1)); + const Domain dom = + var_domains[var].IntersectionWith(Domain(kint64min, p1 - 1)); if (dom.IsEmpty()) { best_jump = {p1, v1}; } else { @@ -643,8 +673,8 @@ std::pair FeasibilityJumpSolver::ComputeLinearJump(int var) { if (v2 < 0.0) { // Point p2 is improving. Look for best after it. // Similarly, we exclude the other points by convexity. - const Domain dom = var_domains[var].IntersectionWith( - Domain(p2 + 1, std::numeric_limits::max())); + const Domain dom = + var_domains[var].IntersectionWith(Domain(p2 + 1, kint64max)); if (dom.IsEmpty()) { best_jump = {p2, v2}; } else { diff --git a/ortools/sat/feasibility_jump.h b/ortools/sat/feasibility_jump.h index 66c255d9622..e2d8229d924 100644 --- a/ortools/sat/feasibility_jump.h +++ b/ortools/sat/feasibility_jump.h @@ -32,6 +32,7 @@ #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/constraint_violation.h" #include "ortools/sat/cp_model_copy.h" #include "ortools/sat/integer_base.h" @@ -41,6 +42,7 @@ #include "ortools/sat/subsolver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" +#include "ortools/util/random_engine.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/time_limit.h" @@ -177,14 +179,14 @@ struct LsOptions { use_objective == o.use_objective; } - void Randomize(const SatParameters& params, ModelRandomGenerator* random) { + void Randomize(const SatParameters& params, absl::BitGenRef random) { perturbation_probability = - absl::Bernoulli(*random, 0.5) + absl::Bernoulli(random, 0.5) ? 0.0 : params.feasibility_jump_var_randomization_probability(); - use_decay = absl::Bernoulli(*random, 0.5); - use_compound_moves = absl::Bernoulli(*random, 0.5); - use_objective = absl::Bernoulli(*random, 0.5); + use_decay = absl::Bernoulli(random, 0.5); + use_compound_moves = absl::Bernoulli(random, 0.5); + use_objective = absl::Bernoulli(random, 0.5); } }; @@ -245,7 +247,7 @@ struct LsState { int64_t num_batches_before_change = 0; // Used by LS to know the rank of the starting solution for this state. - int64_t last_solution_rank = std::numeric_limits::max(); + int64_t last_solution_rank = kint64max; // Tricky: If this changed since last time, we need to recompute the // compound moves as the objective constraint bound changed. @@ -390,7 +392,7 @@ class FeasibilityJumpSolver : public SubSolver { shared_response_(shared_response), shared_hints_(shared_hints), stat_tables_(stat_tables), - random_(params_) { + random_(params_.random_seed()) { shared_time_limit_->UpdateLocalLimit(&time_limit_); } @@ -512,7 +514,7 @@ class FeasibilityJumpSolver : public SubSolver { SharedResponseManager* shared_response_; SharedLsSolutionRepository* shared_hints_; SharedStatTables* stat_tables_; - ModelRandomGenerator random_; + random_engine_t random_; // Whether each `dense_model_` variable occurs in a positive/negative term in // the objective. diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index 3da5ba15fb4..c06a3cfcbc2 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -24,6 +24,7 @@ #include "absl/log/check.h" #include "absl/meta/type_traits.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/glop/parameters.pb.h" #include "ortools/glop/revised_simplex.h" #include "ortools/glop/status.h" @@ -705,15 +706,12 @@ void FeasibilityPump::FillIntegerSolutionStats() { for (const auto& term : integer_lp_[i].terms) { const int64_t prod = CapProd(integer_solution_[term.first.value()], term.second.value()); - if (prod <= std::numeric_limits::min() || - prod >= std::numeric_limits::max()) { + if (prod <= kint64min || prod >= kint64max) { activity = prod; break; } activity = CapAdd(activity, prod); - if (activity <= std::numeric_limits::min() || - activity >= std::numeric_limits::max()) - break; + if (activity <= kint64min || activity >= kint64max) break; } if (activity > integer_lp_[i].ub || activity < integer_lp_[i].lb) { integer_solution_is_feasible_ = false; diff --git a/ortools/sat/gate_congruence_closure.cc b/ortools/sat/gate_congruence_closure.cc index 424ac5a1175..3a3c3e9e69f 100644 --- a/ortools/sat/gate_congruence_closure.cc +++ b/ortools/sat/gate_congruence_closure.cc @@ -34,6 +34,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/graph_base/connected_components.h" #include "ortools/sat/clause.h" #include "ortools/sat/gate_utils.h" @@ -258,7 +259,7 @@ void GateCongruenceClosure::ExtractAndGatesAndFillShortTruthTables( } // Used for an optimization below. - int min_num_implications = std::numeric_limits::max(); + int min_num_implications = kint32max; Literal lit_with_less_implications; const int clause_size = clause->size(); @@ -1690,6 +1691,51 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { } } + if (params_.inprocessing_detect_and_sweep_circuit()) { + // Use lower index as representative. + util_intops::StrongVector lowest_rep( + num_variables, kNoLiteralIndex); + for (BooleanVariable var(0); var < num_variables; ++var) { + const Literal rep1 = + lrat_helper.GetRepresentativeWithProofSupport(Literal(var, true)); + if (rep1.Variable() == var) continue; + + Literal current_rep = rep1; + if (lowest_rep[rep1.Variable()] != kNoLiteralIndex) { + current_rep = Literal(lowest_rep[rep1.Variable()]); + if (rep1.IsNegative()) { + current_rep = current_rep.Negated(); + } + } + + if (current_rep.Variable() > var) { + if (lowest_rep[current_rep.Variable()] == kNoLiteralIndex) { + lowest_rep[current_rep.Variable()] = + Literal(var, current_rep.IsPositive()).Index(); + } + } else if (lowest_rep[var] == kNoLiteralIndex) { + lowest_rep[var] = current_rep.Index(); + } + } + + // Test the code above. + // TODO(user): unit test instead. + if (DEBUG_MODE) { + for (BooleanVariable var(0); var < num_variables; ++var) { + if (lowest_rep[var] != kNoLiteralIndex) { + CHECK_LT(Literal(lowest_rep[var]).Variable(), var); + CHECK_EQ( + lrat_helper.GetRepresentativeWithProofSupport( + Literal(lowest_rep[var])), + lrat_helper.GetRepresentativeWithProofSupport(Literal(var, true))) + << var; + } + } + } + + ExploitCircuitStructure(lowest_rep); + } + if (DEBUG_MODE) { CHECK_EQ(num_processed_fixed_variables, trail_->Index()); CHECK(queue.empty()); @@ -1737,5 +1783,130 @@ bool GateCongruenceClosure::DoOneRound(bool log_info) { return true; } +void GateCongruenceClosure::ExploitCircuitStructure( + const util_intops::StrongVector& + lowest_rep) { + const auto get_rep = [&lowest_rep](Literal lit) { + if (lowest_rep[lit.Variable()] == kNoLiteralIndex) return lit; + return lit.IsPositive() ? Literal(lowest_rep[lit.Variable()]) + : Literal(lowest_rep[lit.Variable()]).Negated(); + }; + + // Compute a "defining" gate for each Boolean. + const int num_variables(sat_solver_->NumVariables()); + util_intops::StrongVector defining_id(num_variables, + GateId(-1)); + const int num_gates = gates_inputs_.size(); + for (GateId id(0); id < num_gates; ++id) { + if (gates_type_[id] == kAndGateType) continue; + if (gates_inputs_[id].size() != 2) continue; + if (assignment_.LiteralIsAssigned(Literal(gates_target_[id]))) continue; + + const Literal target = get_rep(gates_target_[id]); + const Literal a = get_rep(gates_inputs_[id][0]); + const Literal b = get_rep(gates_inputs_[id][1]); + if (a.Index() >= target.Index()) continue; + if (b.Index() >= target.Index()) continue; + + if (defining_id[target.Variable()] < 0) { + defining_id[target.Variable()] = id; + } + } + + // Assume topological order follow the variable ones, and reconstruct a + // circuit. This is why we used the "lowest" representative, to try to keep + // this topological order. + util_intops::StrongVector is_used(num_variables, + false); + util_intops::StrongVector is_defined(num_variables, + false); + int num_defining_gates = 0; + for (BooleanVariable var(0); var < num_variables; ++var) { + if (defining_id[var] < 0) continue; + + const GateId id = defining_id[var]; + const Literal target = get_rep(gates_target_[id]); + const Literal a = get_rep(gates_inputs_[id][0]); + const Literal b = get_rep(gates_inputs_[id][1]); + + ++num_defining_gates; + is_used[a.Variable()] = true; + is_used[b.Variable()] = true; + is_defined[target.Variable()] = true; + } + + // Recover the inputs. + BinaryCircuit circuit; + circuit.mapping.assign(num_variables, -1); + for (BooleanVariable var(0); var < num_variables; ++var) { + if (is_used[var] && !is_defined[var]) { + circuit.mapping[var] = circuit.reverse_mapping.size(); + circuit.reverse_mapping.push_back(var); + } + } + circuit.num_inputs = circuit.reverse_mapping.size(); + + // Recover the gates. + for (BooleanVariable var(0); var < num_variables; ++var) { + if (defining_id[var] < 0) continue; + + const GateId id = defining_id[var]; + SmallBitset type = gates_type_[id]; + Literal target = get_rep(gates_target_[id]); + if (target.IsNegative()) { + target = target.Negated(); + type ^= 0b1111; + } + CHECK_EQ(target.Variable(), var); + + const Literal a = get_rep(gates_inputs_[id][0]); + const Literal b = get_rep(gates_inputs_[id][1]); + if (circuit.mapping[a.Variable()] == -1) continue; + if (circuit.mapping[b.Variable()] == -1) continue; + + // Change the type so that a and b are positive. + int swap = 0; + if (!a.IsPositive()) swap |= 1; + if (!b.IsPositive()) swap |= 2; + if (swap != 0) { + SmallBitset new_type = 0; + for (int i = 0; i < 4; ++i) { + new_type |= ((type >> i) & 1) << (i ^ swap); + } + type = new_type; + } + + CHECK_EQ(circuit.mapping[var], -1); + circuit.mapping[var] = circuit.reverse_mapping.size(); + circuit.reverse_mapping.push_back(var); + + circuit.gates.emplace_back(type, circuit.mapping[var], + circuit.mapping[a.Variable()], + circuit.mapping[b.Variable()]); + if (!is_used[var]) { + circuit.outputs.push_back(circuit.mapping[var]); + } + } + circuit.num_vars = circuit.reverse_mapping.size(); + + VLOG(2) << "============================== "; + VLOG(2) << circuit.DebugString(); + ReduceGates(&circuit); // Make the sampling more efficient. + VLOG(2) << circuit.DebugString(); + + auto new_equiv = + SimplifyCircuit(/*max_num_solve=*/50, random_, solve_cp_model_callback_, + &saved_sampled_solutions_, &circuit); + VLOG(2) << circuit.DebugString(); + + gtl::STLSortAndRemoveDuplicates(&new_equiv); + for (const auto [a, b] : new_equiv) { + implication_graph_->AddImplication(a, b); + implication_graph_->AddImplication(b, a); + } + VLOG(2) << "#EQUIVALENCES " << new_equiv.size(); + VLOG(2) << "============================== "; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/gate_congruence_closure.h b/ortools/sat/gate_congruence_closure.h index 919ce505a4b..1bd11d04fea 100644 --- a/ortools/sat/gate_congruence_closure.h +++ b/ortools/sat/gate_congruence_closure.h @@ -64,6 +64,7 @@ class GateCongruenceClosure { public: explicit GateCongruenceClosure(Model* model) : assignment_(model->GetOrCreate()->Assignment()), + params_(*model->GetOrCreate()), sat_solver_(model->GetOrCreate()), trail_(model->GetOrCreate()), implication_graph_(model->GetOrCreate()), @@ -71,7 +72,8 @@ class GateCongruenceClosure { lrat_proof_handler_(model->Mutable()), shared_stats_(model->GetOrCreate()), logger_(model->GetOrCreate()), - time_limit_(model->GetOrCreate()) {} + time_limit_(model->GetOrCreate()), + random_(*model->GetOrCreate()) {} ~GateCongruenceClosure(); @@ -81,6 +83,13 @@ class GateCongruenceClosure { // run to try to keep the structural information from the model. void EarlyGateDetection(); + // Can be used to solve SAT sweeping subproblems. + // This should be configured to be relatively fast. + void SetSolveCallback(std::function + solve_cp_model_callback) { + solve_cp_model_callback_ = solve_cp_model_callback; + } + private: DEFINE_STRONG_INDEX_TYPE(TruthTableId); @@ -155,7 +164,16 @@ class GateCongruenceClosure { // Returns its number of inputs. int CanonicalizeShortGate(GateId id); + // Infer a "circuit" from the binary gate, and using sampling see if we + // can detect some extra equivalences. + // + // TODO(user): We currently do not have LRAT support for this. + void ExploitCircuitStructure( + const util_intops::StrongVector& + lowest_rep); + const VariablesAssignment& assignment_; + const SatParameters& params_; SatSolver* sat_solver_; Trail* trail_; BinaryImplicationGraph* implication_graph_; @@ -164,6 +182,12 @@ class GateCongruenceClosure { SharedStatistics* shared_stats_; SolverLogger* logger_; TimeLimit* time_limit_; + absl::BitGenRef random_; + + // For the experimental inprocessing_detect_and_sweep_circuit. + std::vector> saved_sampled_solutions_; + std::function + solve_cp_model_callback_ = nullptr; SparseBitset marked_; SparseBitset seen_; diff --git a/ortools/sat/gate_utils.cc b/ortools/sat/gate_utils.cc index 9276f3f4f6c..9dc858fb374 100644 --- a/ortools/sat/gate_utils.cc +++ b/ortools/sat/gate_utils.cc @@ -25,11 +25,14 @@ #include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/log/vlog_is_on.h" #include "absl/numeric/bits.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "ortools/base/helpers.h" +#include "ortools/base/options.h" #include "ortools/base/stl_util.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/sat_base.h" @@ -38,9 +41,16 @@ namespace operations_research::sat { std::string BinaryCircuit::DebugString() const { + int max_depth = 0; + std::vector depths(num_vars); + // All these case should be easily simplifiable. int num_todo = 0; for (const BinaryGate& gate : gates) { + depths[gate.target] = std::max(depths[gate.target], depths[gate.a] + 1); + depths[gate.target] = std::max(depths[gate.target], depths[gate.b] + 1); + max_depth = std::max(max_depth, depths[gate.target]); + if (gate.a == gate.b || gate.type == 0b0000 || gate.type == 0b1111 || gate.type == 0b0101 || gate.type == 0b1010 || gate.type == 0b0011 || gate.type == 0b1100) { @@ -50,8 +60,9 @@ std::string BinaryCircuit::DebugString() const { return absl::StrCat("#inputs:", num_inputs, " #vars:", num_vars, " #gates:", gates.size(), " #outputs:", outputs.size(), - " #constraints:", gates.size() - (num_vars - num_inputs), - " #simplifiable:", num_todo); + " #constraints:", + static_cast(gates.size()) - (num_vars - num_inputs), + " #simplifiable:", num_todo, " #depth:", max_depth); } void RemoveConstraints(BinaryCircuit* circuit) { @@ -652,6 +663,7 @@ void ReduceGates(BinaryCircuit* circuit) { circuit->gates[new_gate_index].type = new_type; circuit->gates[new_gate_index].a = queue[0]; circuit->gates[new_gate_index].b = queue[1]; + circuit->gates[new_gate_index].Simplify(); } } } @@ -671,7 +683,7 @@ std::string ToDotFile(const BinaryCircuit& circuit, std::vector out_degree(circuit.num_vars, 0); std::vector num_def(circuit.num_vars, 0); std::vector types(circuit.num_vars, 0); - std::vector> reverse_arcs; + CompactVectorVectorBuilder dependency_builder; for (const BinaryGate& gate : circuit.gates) { if (gate.target == BinaryGate::kConstraintTarget) continue; @@ -681,14 +693,14 @@ std::string ToDotFile(const BinaryCircuit& circuit, num_def[gate.target]++; if (gate.a < gate.target && gate.b < gate.target) { - reverse_arcs.push_back({gate.target, gate.a}); + dependency_builder.Add(gate.target, gate.a); if (gate.a != gate.b) { - reverse_arcs.push_back({gate.target, gate.b}); + dependency_builder.Add(gate.target, gate.b); } } } - CompactVectorVector dependency; - dependency.ResetFromPairs(reverse_arcs, circuit.num_vars); + const CompactVectorVector dependency(dependency_builder, + circuit.num_vars); std::vector nodes; std::vector> arcs; @@ -805,17 +817,17 @@ std::string ToBenchFile(const BinaryCircuit& circuit) { SubcircuitExtractor::SubcircuitExtractor(const BinaryCircuit& circuit) : mitter_(circuit) { // Do some precomputation. - std::vector> reverse_arcs; + CompactVectorVectorBuilder dependency_builder; for (const BinaryGate& gate : circuit.gates) { if (gate.target == BinaryGate::kConstraintTarget) continue; if (gate.a < gate.target && gate.b < gate.target) { - reverse_arcs.push_back({gate.target, gate.a}); + dependency_builder.Add(gate.target, gate.a); if (gate.a != gate.b) { - reverse_arcs.push_back({gate.target, gate.b}); + dependency_builder.Add(gate.target, gate.b); } } } - dependency_.ResetFromPairs(reverse_arcs, circuit.num_vars); + dependency_.ResetFromBuilder(dependency_builder, circuit.num_vars); } BinaryCircuit SubcircuitExtractor::Extract(absl::Span literals) { @@ -969,7 +981,7 @@ BinaryCircuit ConstructMitter(const BinaryCircuit& circuit_a, // TODO(user): If one call proved all potential equivalences, we can stop. // TODO(user): congruence closure is faster... resuse sat code somehow? -void SimplifyCircuit( +std::vector> SimplifyCircuit( int max_num_solves, absl::BitGenRef random, std::function solve, std::vector>* saved_solutions, @@ -1000,7 +1012,7 @@ void SimplifyCircuit( if (lmp.num_inputs <= 20) { AddNotEquivalentConstraint(a, b, &lmp); const bool are_equivalent = !BinaryCircuitIsFeasible(lmp); - VLOG(3) << lmp.DebugString() << " equiv: " << are_equivalent << " " << a + VLOG(2) << lmp.DebugString() << " equiv: " << are_equivalent << " " << a << " " << b; if (are_equivalent) new_equiv.push_back({a, b}); continue; @@ -1020,7 +1032,7 @@ void SimplifyCircuit( int num_displayed = 0; std::vector solution; for (int i = 0; i < complexity.size(); ++i) { - if (++num_tried <= max_num_solves) { + if (solve != nullptr && ++num_tried <= max_num_solves) { BinaryCircuit lmp = extractor.Extract({complexity[i].a, complexity[i].b}); AddNotEquivalentConstraint(complexity[i].a, complexity[i].b, &lmp); local_cp_model = ConstructCpModelFromBinaryCircuit(lmp); @@ -1041,9 +1053,26 @@ void SimplifyCircuit( } } saved_solutions->push_back(solution); + } else { + if (VLOG_IS_ON(2)) { + // Dump info for investigation. + const std::string dot_filename = + absl::StrCat("/tmp/dot_unclear_", i, ".dot"); + VLOG(2) << "Dumping to '" << dot_filename << "'"; + CHECK_OK(file::SetContents(dot_filename, ToDotFile(lmp), + file::Defaults())); + + std::string filename = + absl::StrCat("/tmp/submodel_unclear_", i, ".pb.txt"); + VLOG(2) << " Dumping equiv checking submodel to '" << filename << "'"; + CHECK(WriteModelProtoToFile(local_cp_model, filename)); + } + + // Lets disable sat subsolve as soon as we can't solve one. + num_tried = max_num_solves; } - VLOG(3) << i + 1 << "/" << complexity.size() << " " << lmp.DebugString() + VLOG(2) << i + 1 << "/" << complexity.size() << " " << lmp.DebugString() << " equiv: " << proven_equiv << " (with solver) " << complexity[i].a << " " << complexity[i].b; if (proven_equiv) { @@ -1054,9 +1083,9 @@ void SimplifyCircuit( if (++num_displayed <= 5 || i + 5 >= complexity.size()) { if (i + 5 == complexity.size()) { - VLOG(3) << "..."; + VLOG(2) << "..."; } else { - VLOG(3) << "vars " << complexity[i].num_vars << " inputs " + VLOG(2) << "vars " << complexity[i].num_vars << " inputs " << complexity[i].num_inputs; } } @@ -1065,6 +1094,7 @@ void SimplifyCircuit( VLOG(3) << "NEW equivalences" << new_equiv.size(); RemoveEquivalences(new_equiv, circuit); ReduceGates(circuit); + return new_equiv; } void RemoveEquivalences(absl::Span> equiv, @@ -1193,6 +1223,13 @@ void RemoveEquivalences(absl::Span> equiv, } } + // Remap outputs that are equal to their representative. + for (int& out_ref : circuit->outputs) { + if (representative[out_ref] == kNoLiteralIndex) continue; + const Literal lit(representative[out_ref]); + if (lit.IsPositive()) out_ref = circuit->mapping[lit.Variable()]; + } + if (num_extra_equivalences > 0) { VLOG(2) << "num extra equivalences = " << num_extra_equivalences; } diff --git a/ortools/sat/gate_utils.h b/ortools/sat/gate_utils.h index 02236f1d3f3..6282a717f7e 100644 --- a/ortools/sat/gate_utils.h +++ b/ortools/sat/gate_utils.h @@ -24,6 +24,7 @@ #include #include "absl/log/check.h" +#include "absl/numeric/bits.h" #include "absl/random/bit_gen_ref.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" @@ -222,7 +223,7 @@ inline int CanonicalizeFunctionTruthTable(Literal& target, const int new_size = FullyCanonicalizeTruthTable(inputs, function_values); // If we have x = f(a,b,c) and not(y) = f(a,b,c) with the same f, we have an - // equivalence, so we need to canonicalicpze both f() and not(f()) to the same + // equivalence, so we need to canonicalize both f() and not(f()) to the same // function. For that we just always choose to have the lowest bit at zero. if (function_values & 1) { target = target.Negated(); @@ -329,6 +330,15 @@ struct BinaryGate { } } + template + friend H AbslHashValue(H h, const BinaryGate& g) { + return H::combine(std::move(h), g.type, g.target, g.a, g.b); + } + + bool operator==(const BinaryGate& o) const { + return type == o.type && target == o.target && a == o.a && b == o.b; + } + // value[target] = (type >> (value[a] + 2 * value[b])) & 1. SmallBitset type = 0; int target = 0; @@ -423,9 +433,9 @@ CompactVectorVector SampleForEquivalences( const BinaryCircuit& circuit, absl::BitGenRef random, const std::vector>& saved_solutions); -// Find equivalences using sampling, and then proove using either exhaustive +// Find equivalences using sampling, and then prove using either exhaustive // enumeration or sat solving via the solve() function. -void SimplifyCircuit( +std::vector> SimplifyCircuit( int max_num_solve, absl::BitGenRef random, std::function solve, std::vector>* saved_solutions, diff --git a/ortools/sat/inclusion.h b/ortools/sat/inclusion.h index 05a70dbdd00..1cba2e280c6 100644 --- a/ortools/sat/inclusion.h +++ b/ortools/sat/inclusion.h @@ -20,13 +20,13 @@ #include #include #include -#include #include #include #include #include "absl/log/check.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" #include "ortools/util/time_limit.h" @@ -152,7 +152,7 @@ class InclusionDetector { int num_potential_subsets_ = 0; int num_potential_supersets_ = 0; uint64_t work_done_ = 0; - uint64_t work_limit_ = std::numeric_limits::max(); + uint64_t work_limit_ = kuint64max; bool stop_ = false; bool stop_with_current_subset_ = false; @@ -205,7 +205,7 @@ class SubsetsDetector { TimeLimit* time_limit_; uint64_t work_done_ = 0; - uint64_t work_limit_ = std::numeric_limits::max(); + uint64_t work_limit_ = kuint64max; struct OneWatcherData { int index; @@ -426,8 +426,7 @@ inline void SubsetsDetector::IndexAllStorageAsSubsets() { stop_ = false; // Flat representation of one_watcher_, we will fill it in one go from there. - std::vector tmp_keys; - std::vector tmp_values; + CompactVectorVectorBuilder one_watcher_builder; std::vector element_to_num_watched; work_done_ = 0; @@ -469,11 +468,10 @@ inline void SubsetsDetector::IndexAllStorageAsSubsets() { DCHECK_NE(best_choice, second_choice); element_to_num_watched[best_choice]++; - tmp_keys.push_back(best_choice); - tmp_values.push_back({index, second_choice, signature}); + one_watcher_builder.Add(best_choice, {index, second_choice, signature}); } - one_watcher_.ResetFromFlatMapping(tmp_keys, tmp_values); + one_watcher_.ResetFromBuilder(one_watcher_builder); } template diff --git a/ortools/sat/integer.cc b/ortools/sat/integer.cc index 41f91ebbce2..5714b144f02 100644 --- a/ortools/sat/integer.cc +++ b/ortools/sat/integer.cc @@ -35,6 +35,7 @@ #include "absl/types/span.h" #include "ortools/base/log_severity.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" @@ -76,6 +77,7 @@ void IntegerEncoder::ReserveSpaceForNumVariables(int num_vars) { encoding_by_var_.reserve(num_vars); equality_to_associated_literal_.reserve(num_vars); equality_by_var_.reserve(num_vars); + is_fully_encoded_.reserve(num_vars); } void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) { @@ -256,7 +258,7 @@ std::pair IntegerEncoder::Canonicalize( IntegerValue before(i_lit.bound - 1); DCHECK_GE(before, domains_[index].Min()); DCHECK_LE(after, domains_[index].Max()); - int64_t previous = std::numeric_limits::min(); + int64_t previous = kint64min; for (const ClosedInterval& interval : domains_[index]) { if (before > previous && before < interval.start) before = previous; if (after > previous && after < interval.start) after = interval.start; @@ -868,11 +870,16 @@ IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, DCHECK_GE(lower_bound, kMinIntegerValue); DCHECK_LE(lower_bound, upper_bound); DCHECK_LE(upper_bound, kMaxIntegerValue); - DCHECK(lower_bound >= 0 || - lower_bound + std::numeric_limits::max() >= upper_bound); + DCHECK(lower_bound >= 0 || lower_bound + kint64max >= upper_bound); DCHECK(integer_search_levels_.empty()); DCHECK_EQ(var_lbs_.size(), integer_trail_.size()); + // This is needed if we create integer variable AFTER we did some propagation + // to make sure the extra info for any new variable is clean. + if (extra_trail_info_.size() > integer_trail_.size()) { + extra_trail_info_.resize(integer_trail_.size()); + } + const IntegerVariable i(var_lbs_.size()); var_lbs_.push_back(lower_bound); var_trail_index_.push_back(integer_trail_.size()); @@ -883,6 +890,7 @@ IntegerVariable IntegerTrail::AddIntegerVariable(IntegerValue lower_bound, var_lbs_.push_back(-upper_bound); var_trail_index_.push_back(integer_trail_.size()); integer_trail_.push_back({-upper_bound, NegationOf(i)}); + CHECK_LT(integer_trail_.size(), kint32max); var_trail_index_cache_.resize(var_lbs_.size(), integer_trail_.size()); tmp_var_to_trail_index_in_queue_.resize(var_lbs_.size(), 0); @@ -1670,6 +1678,7 @@ ReasonIndex IntegerTrail::AppendReasonToInternalBuffers( literals_reason_buffer_.insert(literals_reason_buffer_.end(), literal_reason.begin(), literal_reason.end()); + CHECK_LT(literals_reason_buffer_.size(), kint32max); } cached_sizes_.push_back(-1); @@ -1677,6 +1686,7 @@ ReasonIndex IntegerTrail::AppendReasonToInternalBuffers( if (!integer_reason.empty()) { bounds_reason_buffer_.insert(bounds_reason_buffer_.end(), integer_reason.begin(), integer_reason.end()); + CHECK_LT(bounds_reason_buffer_.size(), kint32max); } return reason_index; @@ -1896,16 +1906,25 @@ void IntegerTrail::PushOnTrail(IntegerLiteral i_lit, int prev_trail_index, int bool_trail_index, ReasonIndex reason_index, int assignment_level) { const int i = integer_trail_.size(); + DCHECK_GE(prev_trail_index, 0); + DCHECK_LT(prev_trail_index, i); + integer_trail_.push_back({/*bound=*/i_lit.bound, /*var=*/i_lit.var, /*prev_trail_index=*/prev_trail_index, /*reason_index=*/reason_index}); + CHECK_LT(integer_trail_.size(), kint32max); if (!new_conflict_resolution_) return; if (extra_trail_info_.size() < integer_trail_.size()) { extra_trail_info_.resize(integer_trail_.size()); } + // The bool_trail_index should be non-decreasing between the previous trail + // entry on the same variable, otherwise we have an issue with the conflict + // resolution code. + DCHECK_GE(bool_trail_index, + extra_trail_info_[prev_trail_index].bool_trail_index); extra_trail_info_[i] = {/*assignment_level=*/assignment_level, /*bool_trail_index=*/bool_trail_index}; } @@ -1956,15 +1975,32 @@ bool IntegerTrail::EnqueueAssociatedIntegerLiteral(IntegerLiteral i_lit, const ReasonIndex reason_index = AppendReasonToInternalBuffers({literal_reason.Negated()}, {}); - const int prev_trail_index = var_trail_index_[i_lit.var]; + int prev_trail_index = var_trail_index_[i_lit.var]; var_lbs_[i_lit.var] = i_lit.bound; var_trail_index_[i_lit.var] = integer_trail_.size(); // We use as a boolean trail_index the next one after the one of // literal_reason. - PushOnTrail(i_lit, prev_trail_index, - trail_->Info(literal_reason.Variable()).trail_index + 1, - reason_index, trail_->AssignmentLevel(literal_reason)); + // + // Important: This is rare, but it can happen that this stronger bound was + // derived using a lower bool_trail_index than the current best bound we have. + // In this case, to ensure correctness of the conflict resolution code, we + // will clear all such "dominated" previous entry. + const int bool_trail_index = + trail_->Info(literal_reason.Variable()).trail_index + 1; + if (new_conflict_resolution_) { + const int num_vars = var_lbs_.size(); + while (prev_trail_index >= num_vars && + prev_trail_index < extra_trail_info_.size() && + extra_trail_info_[prev_trail_index].bool_trail_index > + bool_trail_index) { + integer_trail_[prev_trail_index].var = kNoIntegerVariable; + prev_trail_index = integer_trail_[prev_trail_index].prev_trail_index; + } + } + + PushOnTrail(i_lit, prev_trail_index, bool_trail_index, reason_index, + trail_->AssignmentLevel(literal_reason)); return true; } @@ -1994,7 +2030,7 @@ void IntegerTrail::ComputeLazyReasonIfNeeded(ReasonIndex index) const { // mainly save a FindLowestTrailIndexThatExplainBound() call per skipped // indices, which can still be costly. const int index = tmp_var_to_trail_index_in_queue_[i_lit.var]; - if (index == std::numeric_limits::max()) continue; + if (index == kint32max) continue; if (index > 0 && integer_trail_[index].bound >= i_lit.bound) { has_dependency_ = true; continue; diff --git a/ortools/sat/integer.h b/ortools/sat/integer.h index 974ce3bed9a..049aa628bd1 100644 --- a/ortools/sat/integer.h +++ b/ortools/sat/integer.h @@ -34,6 +34,7 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/lrat_proof_handler.h" #include "ortools/sat/model.h" @@ -502,7 +503,7 @@ class LazyReasonInterface { // alternatively, push "place_holder" to the boolean trail. But for // non-chronological backtracking, having the assignment level here seems nice. struct GlobalTrailIndex { - constexpr static int kNoIntegerIndex = std::numeric_limits::max(); + constexpr static int kNoIntegerIndex = kint32max; int level; int bool_index; @@ -1067,7 +1068,7 @@ class IntegerTrail final : public SatPropagator { IntegerLiteral i_lit, bool use_lazy_reason, absl::Span literal_reason, absl::Span integer_reason, - int trail_index_with_same_reason = std::numeric_limits::max()); + int trail_index_with_same_reason = kint32max); // Internal implementation of the EnqueueLiteral() functions. ABSL_MUST_USE_RESULT bool EnqueueLiteralInternal( @@ -1148,8 +1149,8 @@ class IntegerTrail final : public SatPropagator { // was added. This is only used when use_new_integer_conflict_resolution is // true. struct ExtraTrailEntry { - int32_t assignment_level; - int32_t bool_trail_index; + int32_t assignment_level = 0; + int32_t bool_trail_index = 0; }; std::vector extra_trail_info_; diff --git a/ortools/sat/integer_base.cc b/ortools/sat/integer_base.cc index 2575d49d732..f0487915412 100644 --- a/ortools/sat/integer_base.cc +++ b/ortools/sat/integer_base.cc @@ -16,12 +16,13 @@ #include #include #include +#include #include #include #include #include "absl/log/check.h" -#include "ortools/util/bitset.h" +#include "ortools/base/mathutil.h" namespace operations_research::sat { @@ -175,6 +176,35 @@ AffineExpression LinearExpression2::GetAffineLowerBound( CeilRatio(nominator, coeff)); } +std::optional LinearExpression2::GetDifferenceLowerBound( + IntegerValue lb, AffineExpression t2, AffineExpression t1) { + DCHECK_EQ(vars[0], NegationOf(t1.var)); + DCHECK_EQ(vars[1], t2.var); + DCHECK_GT(coeffs[0], 0); + DCHECK_GT(coeffs[1], 0); + DCHECK_GT(t1.coeff, 0); + DCHECK_GT(t2.coeff, 0); + // We have + // a.x + b.y >= lb + // t1 = c.(-x) + d + // t2 = e.y + f + // with a, b, c, d > 0. This can be rewritten as: + // -(a / c) * (t1 - d) + (b / e) * (t2 - f) >= lb + // If a / c and b / e are integer and equal to k, then we get: + // -k * (t1 - d) + k * (t2 - f) >= lb + // which yields: + // t2 >= t1 + (lb / k) + f - d + if (coeffs[0] % t1.coeff != 0 || coeffs[1] % t2.coeff != 0) { + return std::nullopt; + } + const IntegerValue ke = coeffs[0] / t1.coeff; + const IntegerValue ks = coeffs[1] / t2.coeff; + if (ks != ke) { + return std::nullopt; + } + return MathUtil::CeilOfRatio(lb, ks) + t2.constant - t1.constant; +} + void LinearExpression2::MakeVariablesPositive() { SimpleCanonicalization(); for (int i = 0; i < 2; ++i) { @@ -272,4 +302,21 @@ BestBinaryRelationBounds::GetSortedNonTrivialBounds() const { return root_relations_sorted; } +std::pair BestBinaryRelationBounds::GetBounds( + LinearExpression2 expr) const { + expr.SimpleCanonicalization(); + const IntegerValue gcd = expr.DivideByGcd(); + const bool negated = expr.NegateForCanonicalization(); + const auto it = best_bounds_.find(expr); + if (it != best_bounds_.end()) { + const auto [known_lb, known_ub] = it->second; + if (negated) { + return {CapProdI(-known_ub, gcd), CapProdI(-known_lb, gcd)}; + } else { + return {CapProdI(known_lb, gcd), CapProdI(known_ub, gcd)}; + } + } + return {kMinIntegerValue, kMaxIntegerValue}; +} + } // namespace operations_research::sat diff --git a/ortools/sat/integer_base.h b/ortools/sat/integer_base.h index d63cae3cf06..4ac6cb30c58 100644 --- a/ortools/sat/integer_base.h +++ b/ortools/sat/integer_base.h @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include #include @@ -405,6 +405,14 @@ struct LinearExpression2 { AffineExpression GetAffineLowerBound(int var_index, IntegerValue expr_lb, IntegerValue other_var_lb) const; + // If `this` >= `lb` is of the form k*(t2 - t1) >= k*delta_t with k > 0, + // returns delta_t. Otherwise, returns nullopt. The variables in `t1` and `t2` + // must be `NegationOf(vars[0])` and `vars[1]`, respectively. coeffs[0] and + // coeffs[1] must be positive. + std::optional GetDifferenceLowerBound(IntegerValue lb, + AffineExpression t2, + AffineExpression t1); + // Divides the expression by the gcd of both coefficients, and returns it. // Note that we always return something >= 1 even if both coefficients are // zero. @@ -510,6 +518,8 @@ class BestBinaryRelationBounds { std::pair Add(LinearExpression2 expr, IntegerValue lb, IntegerValue ub); + std::pair GetBounds(LinearExpression2 expr) const; + // Returns the known status of expr <= bound. RelationStatus GetStatus(LinearExpression2 expr, IntegerValue lb, IntegerValue ub) const; diff --git a/ortools/sat/integer_base_test.cc b/ortools/sat/integer_base_test.cc index 5285b4e383c..26cf01533e6 100644 --- a/ortools/sat/integer_base_test.cc +++ b/ortools/sat/integer_base_test.cc @@ -15,6 +15,7 @@ #include #include +#include #include #include "absl/log/check.h" @@ -204,5 +205,38 @@ TEST(Linear2BoundAffineRelaxationTest, Random) { } } +TEST(LinearExpression2Test, GetDifferenceLowerBound) { + const IntegerVariable x = IntegerVariable(0); + const IntegerVariable y = IntegerVariable(2); + + // y - x >= 5 => y - x >= 5 + EXPECT_EQ(LinearExpression2(NegationOf(x), y, 1, 1) + .GetDifferenceLowerBound(IntegerValue(5), AffineExpression(y), + AffineExpression(x)), + IntegerValue(5)); + // 2y - 2x >= 5 => y - x >= 5/2 => y - x >= 3 + EXPECT_EQ(LinearExpression2(NegationOf(x), y, 2, 2) + .GetDifferenceLowerBound(IntegerValue(5), AffineExpression(y), + AffineExpression(x)), + IntegerValue(3)); + // 2y - 2x >= 5 => y - x >= 5/2 => (y+2) - (x+4) >= 3 + 2 - 4 = 1 + EXPECT_EQ( + LinearExpression2(NegationOf(x), y, 2, 2) + .GetDifferenceLowerBound(IntegerValue(5), AffineExpression(y, 1, 2), + AffineExpression(x, 1, 4)), + IntegerValue(1)); + // 2y - x >= 5 implies nothing about y - x. + EXPECT_EQ(LinearExpression2(NegationOf(x), y, 1, 2) + .GetDifferenceLowerBound(IntegerValue(5), AffineExpression(y), + AffineExpression(x)), + std::nullopt); + // 3y - 2x >= 5 implies nothing about 2y - 2x. + EXPECT_EQ( + LinearExpression2(NegationOf(x), y, 2, 3) + .GetDifferenceLowerBound(IntegerValue(5), AffineExpression(y, 2), + AffineExpression(x, 2)), + std::nullopt); +} + } // namespace } // namespace operations_research::sat diff --git a/ortools/sat/integer_expr.cc b/ortools/sat/integer_expr.cc index 33035a062c8..77afab52b13 100644 --- a/ortools/sat/integer_expr.cc +++ b/ortools/sat/integer_expr.cc @@ -27,6 +27,7 @@ #include "absl/numeric/int128.h" #include "absl/types/span.h" #include "ortools/base/mathutil.h" +#include "ortools/base/types.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" @@ -295,7 +296,7 @@ bool LinearConstraintPropagator::Propagate() { // If use_int128 is true, the slack or propagation slack can be larger than // this. To detect overflow with capped arithmetic, it is important the slack // used in our algo never exceed this value. - const absl::int128 max_slack = std::numeric_limits::max() - 1; + const absl::int128 max_slack = kint64max - 1; // Conflict? IntegerValue slack; @@ -1674,14 +1675,19 @@ bool FixedModuloPropagator::Propagate() { if (status != EnforcementStatus::IS_ENFORCED) return true; if (!PropagateSignsAndTargetRange()) return false; - if (!PropagateOuterBounds()) return false; + bool changed = true; + if (!PropagateOuterBounds(&changed)) return false; + + // Subtle: we might need to run PropagateSignsAndTargetRange() again to make + // sure that the invariant `expr >= 0 => target >= 0` is respected. + if (changed) { + if (!PropagateSignsAndTargetRange()) return false; + } if (integer_trail_.LowerBound(expr_) >= 0) { - if (!PropagateBoundsWhenExprIsNonNegative(expr_, target_)) return false; + return PropagateBoundsWhenExprIsNonNegative(expr_, target_); } else if (integer_trail_.UpperBound(expr_) <= 0) { - if (!PropagateBoundsWhenExprIsNonNegative(negated_expr_, negated_target_)) { - return false; - } + return PropagateBoundsWhenExprIsNonNegative(negated_expr_, negated_target_); } return true; @@ -1827,13 +1833,15 @@ bool FixedModuloPropagator::PropagateSignsAndTargetRange() { return true; } -bool FixedModuloPropagator::PropagateOuterBounds() { +bool FixedModuloPropagator::PropagateOuterBounds(bool* changed) { + *changed = false; const IntegerValue min_expr = integer_trail_.LowerBound(expr_); const IntegerValue max_expr = integer_trail_.UpperBound(expr_); const IntegerValue min_target = integer_trail_.LowerBound(target_); const IntegerValue max_target = integer_trail_.UpperBound(target_); if (max_expr % mod_ > max_target) { + *changed = true; if (!enforcement_helper_.SafeEnqueue( enforcement_id_, expr_.LowerOrEqual((max_expr / mod_) * mod_ + max_target), @@ -1844,6 +1852,7 @@ bool FixedModuloPropagator::PropagateOuterBounds() { } if (min_expr % mod_ < min_target) { + *changed = true; if (!enforcement_helper_.SafeEnqueue( enforcement_id_, expr_.GreaterOrEqual((min_expr / mod_) * mod_ + min_target), @@ -1854,6 +1863,7 @@ bool FixedModuloPropagator::PropagateOuterBounds() { } if (min_expr / mod_ == max_expr / mod_) { + *changed = true; if (min_target < min_expr % mod_) { if (!enforcement_helper_.SafeEnqueue( enforcement_id_, @@ -1867,6 +1877,7 @@ bool FixedModuloPropagator::PropagateOuterBounds() { } if (max_target > max_expr % mod_) { + *changed = true; if (!enforcement_helper_.SafeEnqueue( enforcement_id_, target_.LowerOrEqual(max_expr - (max_expr / mod_) * mod_), @@ -1880,6 +1891,7 @@ bool FixedModuloPropagator::PropagateOuterBounds() { } else if (min_expr / mod_ == 0 && min_target < 0) { // expr == target when expr <= 0. if (min_target < min_expr) { + *changed = true; if (!enforcement_helper_.SafeEnqueue( enforcement_id_, target_.GreaterOrEqual(min_expr), {integer_trail_.LowerBoundAsLiteral(target_), @@ -1890,6 +1902,7 @@ bool FixedModuloPropagator::PropagateOuterBounds() { } else if (max_expr / mod_ == 0 && max_target > 0) { // expr == target when expr >= 0. if (max_target > max_expr) { + *changed = true; if (!enforcement_helper_.SafeEnqueue( enforcement_id_, target_.LowerOrEqual(max_expr), {integer_trail_.UpperBoundAsLiteral(target_), diff --git a/ortools/sat/integer_expr.h b/ortools/sat/integer_expr.h index 9ab9bc1a6a9..bd71ed549c2 100644 --- a/ortools/sat/integer_expr.h +++ b/ortools/sat/integer_expr.h @@ -426,7 +426,7 @@ class FixedModuloPropagator : public PropagatorInterface { bool PropagateSignsAndTargetRange(); bool PropagateBoundsWhenExprIsNonNegative(AffineExpression expr, AffineExpression target); - bool PropagateOuterBounds(); + bool PropagateOuterBounds(bool* changed); const AffineExpression expr_; const IntegerValue mod_; diff --git a/ortools/sat/integer_expr_test.cc b/ortools/sat/integer_expr_test.cc index 0816009d97e..a987e0bea62 100644 --- a/ortools/sat/integer_expr_test.cc +++ b/ortools/sat/integer_expr_test.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -33,6 +32,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/parse_test_proto.h" #include "ortools/base/parse_text_proto.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" @@ -1502,12 +1502,12 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { if (z_min > z_max) std::swap(z_min, z_max); // Loop through the domains of x and y, and collect valid bounds. - int expected_x_min = std::numeric_limits::max(); - int expected_x_max = std::numeric_limits::min(); - int expected_y_min = std::numeric_limits::max(); - int expected_y_max = std::numeric_limits::min(); - int expected_z_min = std::numeric_limits::max(); - int expected_z_max = std::numeric_limits::min(); + int expected_x_min = kint32max; + int expected_x_max = kint32min; + int expected_y_min = kint32max; + int expected_y_max = kint32min; + int expected_z_min = kint32max; + int expected_z_max = kint32min; for (int i = x_min; i <= x_max; ++i) { for (int j = y_min; j <= y_max; ++j) { const int k = i / j; @@ -1532,7 +1532,7 @@ TEST(DivisionConstraintTest, CheckAllPropagationsRandomProblem) { EXPECT_BOUNDS_EQ(var_y, expected_y_min, expected_y_max); EXPECT_BOUNDS_EQ(var_z, expected_z_min, expected_z_max); } else { - EXPECT_EQ(expected_x_max, std::numeric_limits::min()); + EXPECT_EQ(expected_x_max, kint32min); } } } @@ -1702,10 +1702,10 @@ TEST(DivisionConstraintTest, CheckAllSolutionsOnExprs) { void TestAllDivisionValues(int64_t min_a, int64_t max_a, int64_t b, int64_t min_c, int64_t max_c) { - int64_t true_min_a = std::numeric_limits::max(); - int64_t true_max_a = std::numeric_limits::min(); - int64_t true_min_c = std::numeric_limits::max(); - int64_t true_max_c = std::numeric_limits::min(); + int64_t true_min_a = kint64max; + int64_t true_max_a = kint64min; + int64_t true_min_c = kint64max; + int64_t true_max_c = kint64min; for (int64_t a = min_a; a <= max_a; ++a) { for (int64_t c = min_c; c <= max_c; ++c) { if (a / b == c) { @@ -1734,7 +1734,7 @@ void TestAllDivisionValues(int64_t min_a, int64_t max_a, int64_t b, EXPECT_EQ(integer_trail->LowerBound(var_c), true_min_c); EXPECT_EQ(integer_trail->UpperBound(var_c), true_max_c); } else { - EXPECT_EQ(true_min_a, std::numeric_limits::max()); // No solution. + EXPECT_EQ(true_min_a, kint64max); // No solution. } } @@ -1791,15 +1791,15 @@ TEST(FixedDivisionConstraintTest, ExpectedPropagation) { /*new_a=*/-8, 2, /*new_c=*/-2, 0)); // Check large domains. EXPECT_TRUE(PropagateFixedDivision( - /*a=*/0, std::numeric_limits::max() / 2, - /*b=*/5, /*c=*/3, std::numeric_limits::max() - 3, - /*new_a=*/15, std::numeric_limits::max() / 2, - /*new_c=*/3, std::numeric_limits::max() / 10)); + /*a=*/0, kint64max / 2, + /*b=*/5, /*c=*/3, kint64max - 3, + /*new_a=*/15, kint64max / 2, + /*new_c=*/3, kint64max / 10)); EXPECT_TRUE(PropagateFixedDivision( - /*a=*/0, std::numeric_limits::max() / 2, - /*b=*/5, /*c=*/3, std::numeric_limits::max() - 3, - /*new_a=*/15, std::numeric_limits::max() / 2, - /*new_c=*/3, std::numeric_limits::max() / 10)); + /*a=*/0, kint64max / 2, + /*b=*/5, /*c=*/3, kint64max - 3, + /*new_a=*/15, kint64max / 2, + /*new_c=*/3, kint64max / 10)); } TEST(FixedDivisionConstraintTest, AlwaysFalseWithUnassignedEnforcementLiteral) { @@ -1879,11 +1879,13 @@ TEST(ModuloConstraintTest, CheckAllSolutions) { LinearArgumentProto* modulo = initial_model.add_constraints()->mutable_int_mod(); modulo->add_exprs()->add_vars(0); // var. - modulo->mutable_exprs(0)->add_coeffs(1); + const int var_coeff = absl::Uniform(random, -3, 3); + modulo->mutable_exprs(0)->add_coeffs(var_coeff); modulo->add_exprs()->add_vars(1); // mod modulo->mutable_exprs(1)->add_coeffs(1); modulo->mutable_target()->add_vars(2); // target - modulo->mutable_target()->add_coeffs(1); + const int target_coeff = absl::Uniform(random, -3, 3); + modulo->mutable_target()->add_coeffs(target_coeff); absl::btree_set> solutions; const CpSolverResponse response = @@ -1892,9 +1894,18 @@ TEST(ModuloConstraintTest, CheckAllSolutions) { // Loop through the domains of var and target, and collect valid solutions. absl::btree_set> expected; for (int i = var_min; i <= var_max; ++i) { - const int k = i % mod; - if (k < target_min || k > target_max) continue; - expected.insert({i, mod, k}); + const int k = (var_coeff * i) % mod; + if (target_coeff == 0 && k != 0) continue; + if (target_coeff == 0) { + for (int j = target_min; j <= target_max; ++j) { + expected.insert({i, mod, j}); + } + continue; + } + if (k % target_coeff != 0) continue; + const int j = k / target_coeff; + if (j < target_min || j > target_max) continue; + expected.insert({i, mod, j}); } // Checks that we get we get the same solution set through the two methods. @@ -1923,10 +1934,10 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { if (target_min > target_max) std::swap(target_min, target_max); // Loop through the domains of var and target, and collect valid bounds. - int expected_var_min = std::numeric_limits::max(); - int expected_var_max = std::numeric_limits::min(); - int expected_target_min = std::numeric_limits::max(); - int expected_target_max = std::numeric_limits::min(); + int expected_var_min = kint32max; + int expected_var_max = kint32min; + int expected_target_min = kint32max; + int expected_target_max = kint32min; for (int i = var_min; i <= var_max; ++i) { const int k = i % mod; if (k < target_min || k > target_max) continue; @@ -1952,7 +1963,7 @@ TEST(ModuloConstraintTest, CheckAllPropagationsRandomProblem) { << model.Get(LowerBound(target)) << ".." << model.Get(UpperBound(target)) << "]"; } else { - EXPECT_EQ(expected_var_max, std::numeric_limits::min()); + EXPECT_EQ(expected_var_max, kint32min); } } } diff --git a/ortools/sat/integer_resolution.cc b/ortools/sat/integer_resolution.cc index 0feef7d0dfd..1dd3006336e 100644 --- a/ortools/sat/integer_resolution.cc +++ b/ortools/sat/integer_resolution.cc @@ -210,6 +210,7 @@ void IntegerConflictResolution::AddToQueue(GlobalTrailIndex source_index, const GlobalTrailIndex index{info.level, info.trail_index}; tmp_queue_.push_back(index); + DCHECK_LT(tmp_queue_.back(), source_index); } } for (const IntegerLiteral i_lit : reason.integer_literals) { @@ -238,6 +239,7 @@ void IntegerConflictResolution::AddToQueue(GlobalTrailIndex source_index, data.in_queue = true; tmp_queue_.push_back( integer_trail_->GlobalIndexAt(data.int_index_in_queue)); + DCHECK_LT(tmp_queue_.back(), source_index); } CHECK_LT(integer_trail_->GlobalIndexAt(data.int_index_in_queue), @@ -292,6 +294,7 @@ void IntegerConflictResolution::ProcessIntegerLiteral( data.in_queue = true; tmp_queue_.push_back( integer_trail_->GlobalIndexAt(data.int_index_in_queue)); + DCHECK_LT(tmp_queue_.back(), source_index); } data.bound = std::max(data.bound, i_lit.bound); @@ -336,8 +339,10 @@ void IntegerConflictResolution::ComputeFirstUIPConflict( const int num_i_vars = integer_trail_->NumIntegerVariables().value(); int_data_.clear(); int_data_.resize(num_i_vars); - // Note the +1 in case we create a new 1-UIP boolean. - tmp_bool_index_seen_.ClearAndResize(trail_->Index() + 1); + // Note the we need some slack because we enqueue a new decision if we see a + // boolean already assigned to true. + constexpr int kSizeSlack = 100; + tmp_bool_index_seen_.ClearAndResize(trail_->Index() + kSizeSlack); tmp_var_to_settled_lb_.assign(num_i_vars, kMinIntegerValue); tmp_queue_.clear(); @@ -393,6 +398,7 @@ void IntegerConflictResolution::ComputeFirstUIPConflict( data.int_index_in_queue = previous_index; tmp_queue_.push_back( integer_trail_->GlobalIndexAt(data.int_index_in_queue)); + DCHECK_LT(tmp_queue_.back(), top_index); CHECK_LE( data.bound, integer_trail_->IntegerLiteralAtIndex(data.int_index_in_queue) @@ -507,6 +513,9 @@ void IntegerConflictResolution::ComputeFirstUIPConflict( if (!trail_->Assignment().LiteralIsAssigned(new_lit)) { // Using a decision should work as we will backtrack right away. trail_->EnqueueSearchDecision(new_lit); + if (trail_->Index() >= tmp_bool_index_seen_.size()) { + tmp_bool_index_seen_.Resize(trail_->Index() + kSizeSlack); + } } // It should be true. @@ -607,7 +616,7 @@ void IntegerConflictResolution::ComputeFirstUIPConflict( // Expand. // - // TODO(user): There is probably a faster way to recover the heap propety + // TODO(user): There is probably a faster way to recover the heap property // than doing it one by one. const int old_size = tmp_queue_.size(); AddToQueue(top_index, diff --git a/ortools/sat/integer_resolution.h b/ortools/sat/integer_resolution.h index 8f3998c9cda..862f13939f4 100644 --- a/ortools/sat/integer_resolution.h +++ b/ortools/sat/integer_resolution.h @@ -21,6 +21,7 @@ #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/clause.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" @@ -94,7 +95,6 @@ class IntegerConflictResolution { // Information about the current content of our tmp_queue_ and our conflict // resolution. SparseBitset tmp_bool_index_seen_; - SparseBitset tmp_bool_seen_; std::vector tmp_integer_literals_; util_intops::StrongVector tmp_var_to_settled_lb_; @@ -104,7 +104,7 @@ class IntegerConflictResolution { // Whether this variable was added in the queue. // If false, index_in_queue will be the index to re-add it with. bool in_queue = false; - int int_index_in_queue = std::numeric_limits::max(); + int int_index_in_queue = kint32max; // We only need var >= bound in the current conflict resolution. // Note that we have: integer_trail_[int_index_in_queue] >= bound. diff --git a/ortools/sat/integer_search.cc b/ortools/sat/integer_search.cc index 2e621df2e06..0058be908e0 100644 --- a/ortools/sat/integer_search.cc +++ b/ortools/sat/integer_search.cc @@ -17,19 +17,25 @@ #include #include #include +#include #include #include #include +#include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/log/vlog_is_on.h" +#include "absl/random/bit_gen_ref.h" #include "absl/random/distributions.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/log_severity.h" +#include "ortools/base/strong_vector.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_mapping.h" @@ -39,7 +45,9 @@ #include "ortools/sat/intervals.h" #include "ortools/sat/linear_constraint_manager.h" #include "ortools/sat/linear_programming_constraint.h" +#include "ortools/sat/linear_propagation.h" #include "ortools/sat/model.h" +#include "ortools/sat/precedences.h" #include "ortools/sat/probing.h" #include "ortools/sat/pseudo_costs.h" #include "ortools/sat/restart.h" @@ -439,9 +447,9 @@ std::function ShaveObjectiveLb(Model* model) { const IntegerVariable obj_var = objective_definition->objective_var; auto* integer_trail = model->GetOrCreate(); auto* sat_solver = model->GetOrCreate(); - auto* random = model->GetOrCreate(); + absl::BitGenRef random = *model->GetOrCreate(); - return [obj_var, integer_trail, sat_solver, random]() { + return [obj_var, integer_trail, sat_solver, rand = random]() { BooleanOrIntegerLiteral result; const int level = sat_solver->CurrentDecisionLevel(); if (level > 0 || obj_var == kNoIntegerVariable) return result; @@ -450,8 +458,9 @@ std::function ShaveObjectiveLb(Model* model) { const IntegerValue obj_ub = integer_trail->UpperBound(obj_var); if (obj_lb == obj_ub) return result; const IntegerValue mid = (obj_ub - obj_lb) / 2; + absl::BitGenRef r = rand; const IntegerValue new_ub = - obj_lb + absl::LogUniform(*random, 0, mid.value()); + obj_lb + absl::LogUniform(r, 0, mid.value()); result.integer_literal = IntegerLiteral::LowerOrEqual(obj_var, new_ub); return result; @@ -479,22 +488,242 @@ std::function PseudoCost(Model* model) { }; } +namespace { +// Detects interval precedence constraints. +class IntervalPrecedencesDetector { + public: + explicit IntervalPrecedencesDetector(Model* model) + : intervals_(*model->GetOrCreate()), + lin2_indices_(*model->GetOrCreate()), + lin2_bounds_(*model->GetOrCreate()), + conditional_bounds_(*model->GetOrCreate()) { + CompactVectorVectorBuilder + intervals_by_start_var_builder; + CompactVectorVectorBuilder + intervals_by_end_var_builder; + for (IntervalVariable interval_var(0); + interval_var < intervals_.NumIntervals(); ++interval_var) { + const AffineExpression start = intervals_.Start(interval_var); + if (start.var != kNoIntegerVariable) { + CHECK_GT(start.coeff, 0); + intervals_by_start_var_builder.Add(start.var, interval_var); + } + const AffineExpression end = intervals_.End(interval_var); + if (end.var != kNoIntegerVariable) { + CHECK_GT(end.coeff, 0); + intervals_by_end_var_builder.Add(end.var, interval_var); + } + } + intervals_by_start_var_.ResetFromBuilder(intervals_by_start_var_builder, + intervals_.NumIntervals()); + intervals_by_end_var_.ResetFromBuilder(intervals_by_end_var_builder, + intervals_.NumIntervals()); + } + + // Detects unconditional interval precedence constraints. Returns, for each + // interval I, the intervals J which should be fixed after it in the + // SchedulingSearchHeuristicHelper (this can include some precedences which + // are not in the original problem, such as artificial precedences between + // optional and non-optional intervals). + CompactVectorVector DetectPrecedences() { + CompactVectorVectorBuilder + interval_precedences_builder; + for (LinearExpression2Index i(0); + i < 2 * lin2_indices_.NumStoredPositiveLinear2(); ++i) { + LinearExpression2 expr = lin2_indices_.GetExpression(i); + if (expr.vars[0] == kNoIntegerVariable || + expr.vars[1] == kNoIntegerVariable) { + continue; + } + DCHECK(expr.IsCanonicalized()); + DCHECK_GT(expr.coeffs[0], 0); + DCHECK_GT(expr.coeffs[1], 0); + const IntegerValue lb = -lin2_bounds_.LevelZeroUpperBound(NegationOf(i)); + for (int i = 0; i < 2; ++i) { + // Let's look for next_interval.start >= interval.end. + const IntegerVariable var = expr.vars[0]; + const IntegerVariable next_var = expr.vars[1]; + if (NegationOf(var) >= intervals_by_end_var_.size() || + next_var >= intervals_by_start_var_.size()) { + continue; + } + for (const IntervalVariable interval_var : + intervals_by_end_var_[NegationOf(var)]) { + const AffineExpression interval_end = intervals_.End(interval_var); + for (const IntervalVariable next_interval_var : + intervals_by_start_var_[next_var]) { + if (next_interval_var == interval_var) continue; + const AffineExpression next_interval_start = + intervals_.Start(next_interval_var); + const std::optional delta_t = + expr.GetDifferenceLowerBound(lb, next_interval_start, + interval_end); + if (delta_t.has_value() && delta_t.value() >= 0) { + interval_precedences_builder.Add(interval_var, next_interval_var); + } + } + } + std::swap(expr.vars[0], expr.vars[1]); + std::swap(expr.coeffs[0], expr.coeffs[1]); + } + for (int i = 0; i < 2; ++i) { + // Let's look for next_interval.start > interval.start (the above code + // can fail to detect some precedences, depending on how the scheduling + // problem is modeled and/or presolved). + const IntegerVariable var = expr.vars[0]; + const IntegerVariable next_var = expr.vars[1]; + if (next_var >= intervals_by_start_var_.size() || + NegationOf(var) >= intervals_by_start_var_.size()) { + continue; + } + for (const IntervalVariable interval_var : + intervals_by_start_var_[NegationOf(var)]) { + const AffineExpression interval_start = + intervals_.Start(interval_var); + for (const IntervalVariable next_interval_var : + intervals_by_start_var_[next_var]) { + if (next_interval_var == interval_var) continue; + const AffineExpression next_interval_start = + intervals_.Start(next_interval_var); + const std::optional delta_t = + expr.GetDifferenceLowerBound(lb, next_interval_start, + interval_start); + if (delta_t.has_value() && delta_t > 0) { + interval_precedences_builder.Add(interval_var, next_interval_var); + } + } + } + std::swap(expr.vars[0], expr.vars[1]); + std::swap(expr.coeffs[0], expr.coeffs[1]); + } + } + + // Add artificial precedences to fix optional intervals before non-optional + // ones which share the same start variable. A mandatory interval can have + // several optional variants, with exactly one being present. In this case + // the above code might detect precedences only between the optional + // intervals. To avoid conflicts in the heuristics below, it is better to + // fix the optional intervals first. + for (int i = 0; i < intervals_by_start_var_.size(); ++i) { + const IntegerVariable var(i); + for (const IntervalVariable interval_var : intervals_by_start_var_[var]) { + for (const IntervalVariable other_interval_var : + intervals_by_start_var_[var]) { + if (interval_var == other_interval_var) continue; + if (!intervals_.IsOptional(interval_var) && + intervals_.IsOptional(other_interval_var)) { + interval_precedences_builder.Add(other_interval_var, interval_var); + } + } + } + } + + CompactVectorVector + interval_precedences; + interval_precedences.ResetFromBuilder(interval_precedences_builder, + intervals_.NumIntervals()); + for (int i = 0; i < interval_precedences.size(); ++i) { + interval_precedences.SortAndRemoveDuplicateValues(IntervalVariable(i)); + } + return interval_precedences; + } + + // Detects conditional precedences between intervals. Returns (J, Δt) pairs + // for each interval I such that if J.start >= I.end + Δt, then all the + // conditional precedences we know about should be satisfied. + CompactVectorVector> + DetectConditionalPrecedences() { + // For each pair of intervals (I, J), the maximum transition time from I to + // J in the conditional bounds of the form lit => J is after I. + absl::flat_hash_map, + IntegerValue> + max_transition_times; + for (int i = 0; i < conditional_bounds_.size(); ++i) { + Relation relation = conditional_bounds_.relation(i); + if (relation.expr.vars[0] == kNoIntegerVariable || + relation.expr.vars[1] == kNoIntegerVariable) { + continue; + } + // There are four ways to interpret 'relation' as a precedence constraint + // between intervals: terms can be swapped and variables can be negated. + for (int j = 0; j < 4; ++j) { + if (NegationOf(relation.expr.vars[0]) < intervals_by_end_var_.size() && + relation.expr.vars[1] < intervals_by_start_var_.size()) { + for (IntervalVariable interval : + intervals_by_end_var_[NegationOf(relation.expr.vars[0])]) { + for (IntervalVariable next_interval : + intervals_by_start_var_[relation.expr.vars[1]]) { + if (next_interval == interval) continue; + const std::optional transition_time = + relation.expr.GetDifferenceLowerBound( + relation.lhs, intervals_.Start(next_interval), + intervals_.End(interval)); + if (transition_time.has_value() && transition_time > 0) { + auto& current_max = + max_transition_times[{interval, next_interval}]; + current_max = std::max(current_max, *transition_time); + } + } + } + } + if (j == 0 || j == 2) { + std::swap(relation.expr.vars[0], relation.expr.vars[1]); + std::swap(relation.expr.coeffs[0], relation.expr.coeffs[1]); + } else if (j == 1) { + relation.expr.vars[0] = NegationOf(relation.expr.vars[0]); + relation.expr.vars[1] = NegationOf(relation.expr.vars[1]); + relation.lhs = -relation.lhs; + relation.rhs = -relation.rhs; + std::swap(relation.lhs, relation.rhs); + } + } + } + CompactVectorVectorBuilder> + conditional_transitions_builder; + for (const auto& [transition_time, value] : max_transition_times) { + conditional_transitions_builder.Add(transition_time.first, + {transition_time.second, value}); + } + CompactVectorVector> + conditional_transitions; + conditional_transitions.ResetFromBuilder(conditional_transitions_builder, + intervals_.NumIntervals()); + return conditional_transitions; + } + + private: + const IntervalsRepository& intervals_; + const Linear2Indices& lin2_indices_; + const RootLevelLinear2Bounds& lin2_bounds_; + const ConditionalLinear2Bounds& conditional_bounds_; + + CompactVectorVector + intervals_by_start_var_; + CompactVectorVector intervals_by_end_var_; +}; + +} // namespace + class SchedulingSearchHeuristicHelper { public: explicit SchedulingSearchHeuristicHelper(Model* model) - : assignment_(model->GetOrCreate()->Assignment()), + : fixed_search_(model->GetOrCreate()->search_branching() == + SatParameters::FIXED_SEARCH), + assignment_(model->GetOrCreate()->Assignment()), repo_(model->GetOrCreate()), heuristic_(model->GetOrCreate()), watcher_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), random_(model->GetOrCreate()), rev_int_repo_(model->GetOrCreate()) { - // To avoid to scan already fixed intervals, we use a simple reversible int. - const int num_intervals = repo_->NumIntervals(); - intervals_.resize(num_intervals); - cached_start_mins_.resize(num_intervals); - for (IntervalVariable i(0); i < num_intervals; ++i) { - intervals_[i.value()] = i; + if (fixed_search_) { + IntervalPrecedencesDetector precedences_detector(model); + successors_ = precedences_detector.DetectPrecedences(); + successors_transition_times_ = + precedences_detector.DetectConditionalPrecedences(); } // Fix the max size of random choices. @@ -508,23 +737,33 @@ class SchedulingSearchHeuristicHelper { top_decisions.reserve(randomization_size_); top_decisions.resize(1); - // Save rev_fixed before we modify it. - rev_int_repo_->SaveState(&rev_fixed_); + // Find out if there was any backtrack since the last call, + // or since the beginning of the search. + const bool backtrack_since_last_call = !rev_is_in_dive_; + if (backtrack_since_last_call) { + RecomputePredecessorCounts(); + } + + if (!first_call_ && backtrack_since_last_call) { + no_backtrack_since_start_ = false; + cached_start_mins_.assign(repo_->NumIntervals(), kMinIntegerValue); + } + first_call_ = false; + const bool use_first_solution_heuristic = + fixed_search_ && no_backtrack_since_start_; + + for (int index = 0; index < intervals_with_only_fixed_predecessors_.size(); + ++index) { + const IntervalVariable interval = + intervals_with_only_fixed_predecessors_[index]; - // TODO(user): we should also precompute fixed precedences and only fix - // interval that have all their predecessos fixed. - const int num_intervals = intervals_.size(); - for (int i = rev_fixed_; i < num_intervals; ++i) { const ToSchedule& worst = top_decisions.back(); - if (rev_is_in_dive_ && cached_start_mins_[i] > worst.start_min) { + if (cached_start_mins_[interval] > worst.start_min) { continue; } - const IntervalVariable interval = intervals_[i]; if (repo_->IsAbsent(interval)) { - std::swap(intervals_[i], intervals_[rev_fixed_]); - std::swap(cached_start_mins_[i], cached_start_mins_[rev_fixed_]); - ++rev_fixed_; + ProcessAbsentOrFixedInterval(index); continue; } @@ -532,9 +771,7 @@ class SchedulingSearchHeuristicHelper { const AffineExpression end = repo_->End(interval); if (repo_->IsPresent(interval) && integer_trail_->IsFixed(start) && integer_trail_->IsFixed(end)) { - std::swap(intervals_[i], intervals_[rev_fixed_]); - std::swap(cached_start_mins_[i], cached_start_mins_[rev_fixed_]); - ++rev_fixed_; + ProcessAbsentOrFixedInterval(index); continue; } @@ -551,7 +788,15 @@ class SchedulingSearchHeuristicHelper { candidate.start_min = integer_trail_->LowerBound(start); candidate.start_max = integer_trail_->UpperBound(start); } - cached_start_mins_[i] = candidate.start_min; + if (use_first_solution_heuristic && + cached_start_mins_[interval] > candidate.start_min) { + IntegerValue delta_t = + cached_start_mins_[interval] - candidate.start_min; + candidate.start_min += delta_t; + candidate.start_max += delta_t; + } else { + cached_start_mins_[interval] = candidate.start_min; + } if (top_decisions.size() < randomization_size_ || candidate.MightBeBetter(worst)) { // Finish filling candidate. @@ -560,6 +805,7 @@ class SchedulingSearchHeuristicHelper { // to time. This is needed to never pick the "artificial" makespan // interval at the end in priority compared to intervals that still // need to be scheduled. + candidate.interval = interval; candidate.start = start; candidate.end = end; candidate.presence = repo_->IsOptional(interval) @@ -604,7 +850,8 @@ class SchedulingSearchHeuristicHelper { // Use the next_decision_override to fix in turn all the variables from // the selected interval. int num_times = 0; - heuristic_->next_decision_override = [this, best, num_times]() mutable { + heuristic_->next_decision_override = [this, use_first_solution_heuristic, + best, num_times]() mutable { if (++num_times > 5) { // We have been trying to fix this interval for a while. Do we miss // some propagation? In any case, try to see if the heuristic above @@ -628,7 +875,19 @@ class SchedulingSearchHeuristicHelper { // We assume that start_min is propagated by now. if (!integer_trail_->IsFixed(best.start)) { const IntegerValue start_min = integer_trail_->LowerBound(best.start); - VLOG(3) << "start == " << start_min; + const IntegerValue cached_start_min = cached_start_mins_[best.interval]; + VLOG(3) << "start == " << start_min + << " cached_start_min == " << cached_start_min; + if (use_first_solution_heuristic && cached_start_min > start_min) { + if (cached_start_min <= integer_trail_->UpperBound(best.start)) { + return BooleanOrIntegerLiteral( + best.start.GreaterOrEqual(cached_start_min)); + } else { + // Our heuristic gave us a decision that is currently false! Lets + // fall back to other heuristic until we are called again. + return BooleanOrIntegerLiteral(); + } + } return BooleanOrIntegerLiteral(best.start.LowerOrEqual(start_min)); } @@ -641,7 +900,15 @@ class SchedulingSearchHeuristicHelper { // Everything is fixed, detach the override. const IntegerValue start = integer_trail_->LowerBound(best.start); - VLOG(2) << "Fixed @[" << start << "," + const IntegerValue end = integer_trail_->LowerBound(best.end); + if (use_first_solution_heuristic) { + for (const auto& [successor, transition_time] : + successors_transition_times_[best.interval]) { + cached_start_mins_[successor] = + std::max(cached_start_mins_[successor], end + transition_time); + } + } + VLOG(2) << "Fixed " << best.interval << " @[" << start << "," << integer_trail_->LowerBound(best.end) << "]" << (best.presence != kNoLiteralIndex ? absl::StrCat(" presence=", @@ -658,6 +925,7 @@ class SchedulingSearchHeuristicHelper { private: struct ToSchedule { + IntervalVariable interval; // Variable to fix. LiteralIndex presence = kNoLiteralIndex; AffineExpression start; @@ -686,6 +954,58 @@ class SchedulingSearchHeuristicHelper { } }; + void RecomputePredecessorCounts() { + const int num_intervals = repo_->NumIntervals(); + num_non_fixed_predecessors_.assign(num_intervals, 0); + if (!successors_.empty()) { + for (IntervalVariable i(0); i < num_intervals; ++i) { + if (IntervalIsAbsentOrFixed(i)) continue; + for (const IntervalVariable j : successors_[i]) { + num_non_fixed_predecessors_[j]++; + } + } + } + intervals_with_only_fixed_predecessors_.clear(); + for (IntervalVariable i(0); i < num_intervals; ++i) { + if (IntervalIsAbsentOrFixed(i)) continue; + if (num_non_fixed_predecessors_[i] == 0) { + intervals_with_only_fixed_predecessors_.push_back(i); + } + } + cached_start_mins_.assign(num_intervals, kMinIntegerValue); + } + + // Removes the interval at the given index in + // intervals_with_only_fixed_predecessors_ and adds the intervals whose + // predecessors are now all fixed or absent. Decrements interval_index to + // account for the removed interval. + void ProcessAbsentOrFixedInterval(int& interval_index) { + const IntervalVariable i = + intervals_with_only_fixed_predecessors_[interval_index]; + std::swap(intervals_with_only_fixed_predecessors_[interval_index], + intervals_with_only_fixed_predecessors_.back()); + intervals_with_only_fixed_predecessors_.pop_back(); + interval_index--; + if (successors_.empty()) return; + for (const IntervalVariable j : successors_[i]) { + DCHECK_GT(num_non_fixed_predecessors_[j], 0); + num_non_fixed_predecessors_[j]--; + if (num_non_fixed_predecessors_[j] == 0) { + if (IntervalIsAbsentOrFixed(j)) continue; + DCHECK( + !absl::c_linear_search(intervals_with_only_fixed_predecessors_, j)); + intervals_with_only_fixed_predecessors_.push_back(j); + } + } + } + + bool IntervalIsAbsentOrFixed(IntervalVariable i) const { + return repo_->IsAbsent(i) || + (repo_->IsPresent(i) && integer_trail_->IsFixed(repo_->Start(i)) && + integer_trail_->IsFixed(repo_->End(i))); + } + + const bool fixed_search_; const VariablesAssignment& assignment_; IntervalsRepository* repo_; SearchHeuristics* heuristic_; @@ -694,11 +1014,31 @@ class SchedulingSearchHeuristicHelper { ModelRandomGenerator* random_; RevIntRepository* rev_int_repo_; - int rev_fixed_ = 0; bool rev_is_in_dive_ = false; + bool first_call_ = true; + bool no_backtrack_since_start_ = true; int randomization_size_ = 1; - std::vector intervals_; - std::vector cached_start_mins_; + + // For each interval I, all the intervals J which can only be fixed once I is + // fixed, or absent (this can include precedences which are not in the + // original problem, such as artificial precedences between optional and + // non-optional intervals). Empty if fixed_search_ is false. + CompactVectorVector successors_; + // For each interval I, of list of (J, Δt) pairs + // such that if J.start >= I.end + Δt, then all the conditional precedences we + // know about should be satisfied. Empty if fixed_search_ is false. + CompactVectorVector> + successors_transition_times_; + + // For each interval I, the number of predecessors of I which are not absent + // and not yet fixed. + util_intops::StrongVector num_non_fixed_predecessors_; + // The intervals whose predecessors are all fixed or absent, and which are not + // themselves fixed or absent. + std::vector intervals_with_only_fixed_predecessors_; + + util_intops::StrongVector cached_start_mins_; }; // A simple heuristic for scheduling models. @@ -1619,6 +1959,30 @@ SatSolver::Status IntegerSearchHelper::SolveIntegerProblem() { // completion. So we cannot report a feasible solution. if (time_limit_->LimitReached()) return SatSolver::LIMIT_REACHED; if (decision == kNoLiteralIndex) { + // In FIXED_SEARCH and until the first backtrack, the linear propagator + // does not fully propagate in order to be faster. This can cause issues + // once all the decisions are taken, because there might still be some + // unassigned variables, or violated constraints. So do one round of full + // propagation before accepting that solution. + if (parameters_.search_branching() == SatParameters::FIXED_SEARCH) { + LinearPropagator* linear_propagator = + model_->Mutable(); + if (linear_propagator != nullptr && + (!linear_propagator->PropagateAll() || + !sat_solver_->FinishPropagation())) { + // Should not happen, but restart if it does (PropagateAll() disables + // the incomplete propagation in the linear propagator, and is a no-op + // after that; hence this can only happen at most once). + LOG(DFATAL) + << "Should not happen: linear propagator PropagateAll() failed"; + sat_solver_->Backtrack(0); + if (!sat_solver_->FinishPropagation()) { + return sat_solver_->UnsatStatus(); + } + continue; + } + } + // Save the current polarity of all Booleans in the solution. It will be // followed for the next SAT decisions. This is known to be a good policy // for optimization problem. Note that for decision problem we don't care @@ -1729,7 +2093,7 @@ ContinuousProber::ContinuousProber(const CpModelProto& model_proto, prober_(model->GetOrCreate()), shared_response_manager_(model->Mutable()), shared_bounds_manager_(model->Mutable()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), active_limit_(parameters_.shaving_search_deterministic_time()) { auto* mapping = model_->GetOrCreate(); absl::flat_hash_set visited; @@ -1774,12 +2138,6 @@ SatSolver::Status ContinuousProber::Probe() { if (!sat_solver_->ResetToLevelZero()) return SatSolver::INFEASIBLE; while (!time_limit_->LimitReached()) { - if (parameters_.use_sat_inprocessing() && - !inprocessing_->InprocessingRound()) { - sat_solver_->NotifyThatModelIsUnsat(); - return sat_solver_->UnsatStatus(); - } - // Store current statistics to detect an iteration without any improvement. const int64_t initial_num_literals_fixed = prober_->num_new_literals_fixed(); @@ -1832,6 +2190,9 @@ SatSolver::Status ContinuousProber::Probe() { const BooleanVariable& bool_var = bool_vars_[current_bool_var_]; if (assignment.VariableIsAssigned(bool_var)) continue; + if (binary_implication_graph_->IsRedundant(Literal(bool_var, true))) { + continue; + } const auto [_, inserted] = probed_bool_vars_.insert(bool_var); if (!inserted) continue; @@ -1939,10 +2300,10 @@ SatSolver::Status ContinuousProber::Probe() { for (; random_pair_of_bool_vars_probed_ < 10000; ++random_pair_of_bool_vars_probed_) { const BooleanVariable bv1 = - bool_vars_[absl::Uniform(*random_, 0, bool_vars_.size())]; + bool_vars_[absl::Uniform(random_, 0, bool_vars_.size())]; if (assignment.VariableIsAssigned(bv1)) continue; const BooleanVariable bv2 = - bool_vars_[absl::Uniform(*random_, 0, bool_vars_.size())]; + bool_vars_[absl::Uniform(random_, 0, bool_vars_.size())]; if (assignment.VariableIsAssigned(bv2) || bv1 == bv2) { continue; } @@ -1970,15 +2331,15 @@ SatSolver::Status ContinuousProber::Probe() { for (; random_triplet_of_bool_vars_probed_ < loop_limit; ++random_triplet_of_bool_vars_probed_) { const BooleanVariable bv1 = - bool_vars_[absl::Uniform(*random_, 0, bool_vars_.size())]; + bool_vars_[absl::Uniform(random_, 0, bool_vars_.size())]; if (assignment.VariableIsAssigned(bv1)) continue; const BooleanVariable bv2 = - bool_vars_[absl::Uniform(*random_, 0, bool_vars_.size())]; + bool_vars_[absl::Uniform(random_, 0, bool_vars_.size())]; if (assignment.VariableIsAssigned(bv2) || bv1 == bv2) { continue; } const BooleanVariable bv3 = - bool_vars_[absl::Uniform(*random_, 0, bool_vars_.size())]; + bool_vars_[absl::Uniform(random_, 0, bool_vars_.size())]; if (assignment.VariableIsAssigned(bv3) || bv1 == bv3 || bv2 == bv3) { continue; } @@ -2081,6 +2442,14 @@ SatSolver::Status ContinuousProber::PeriodicSyncAndCheck() { return SatSolver::INFEASIBLE; } } + + // Run inprocessing. Note that this does nothing if not enough dtime was + // spent since the last call, so it is okay to call it relatively often. + if (parameters_.use_sat_inprocessing() && + !inprocessing_->InprocessingRound()) { + sat_solver_->NotifyThatModelIsUnsat(); + return sat_solver_->UnsatStatus(); + } } return SatSolver::FEASIBLE; diff --git a/ortools/sat/integer_search.h b/ortools/sat/integer_search.h index c94d656740e..0a958024a03 100644 --- a/ortools/sat/integer_search.h +++ b/ortools/sat/integer_search.h @@ -44,7 +44,6 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" -#include "ortools/sat/util.h" #include "ortools/util/strong_integers.h" #include "ortools/util/time_limit.h" @@ -381,7 +380,7 @@ class ContinuousProber { Prober* prober_; SharedResponseManager* shared_response_manager_; SharedBoundsManager* shared_bounds_manager_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; // Statistics. int64_t num_literals_probed_ = 0; diff --git a/ortools/sat/integer_test.cc b/ortools/sat/integer_test.cc index d518485e728..ec5dcd1bc07 100644 --- a/ortools/sat/integer_test.cc +++ b/ortools/sat/integer_test.cc @@ -27,6 +27,7 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/log_severity.h" +#include "ortools/base/types.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/integer_search.h" #include "ortools/sat/model.h" @@ -137,17 +138,15 @@ TEST(NegationOfTest, VectorArgument) { } TEST(IntegerValue, NegatedCannotOverflow) { - EXPECT_GT(kMinIntegerValue - 1, std::numeric_limits::min()); + EXPECT_GT(kMinIntegerValue - 1, kint64min); } TEST(IntegerLiteral, OverflowValueAreCapped) { const IntegerVariable var(0); EXPECT_EQ(IntegerLiteral::GreaterOrEqual(var, kMaxIntegerValue + 1), - IntegerLiteral::GreaterOrEqual( - var, IntegerValue(std::numeric_limits::max()))); + IntegerLiteral::GreaterOrEqual(var, IntegerValue(kint64max))); EXPECT_EQ(IntegerLiteral::LowerOrEqual(var, kMinIntegerValue - 1), - IntegerLiteral::LowerOrEqual( - var, IntegerValue(std::numeric_limits::min()))); + IntegerLiteral::LowerOrEqual(var, IntegerValue(kint64min))); } TEST(IntegerLiteral, NegatedIsIdempotent) { @@ -163,13 +162,11 @@ TEST(IntegerLiteral, NegatedIsIdempotent) { // A bound difference of exactly kint64max is ok. TEST(IntegerTrailDeathTest, LargeVariableDomain) { Model model; - model.Add(NewIntegerVariable(-3, std::numeric_limits::max() - 3)); + model.Add(NewIntegerVariable(-3, kint64max - 3)); if (DEBUG_MODE) { // But one of kint64max + 1 cause a check fail in debug. - EXPECT_DEATH(model.Add(NewIntegerVariable( - -3, std::numeric_limits::max() - 2)), - ""); + EXPECT_DEATH(model.Add(NewIntegerVariable(-3, kint64max - 2)), ""); } } diff --git a/ortools/sat/java/BUILD.bazel b/ortools/sat/java/BUILD.bazel index 1b07333b3ac..3aabc932bb6 100644 --- a/ortools/sat/java/BUILD.bazel +++ b/ortools/sat/java/BUILD.bazel @@ -45,9 +45,7 @@ java_wrap_cc( name = "cp_model_proto", srcs = ["cp_model_proto.swig"], package = "com.google.ortools.sat", - swig_includes = [ - "//ortools/base:base_swig", - ], + swig_includes = ["//ortools/base:base_swig"], visibility = ["//visibility:public"], deps = [ "//ortools/port:proto_utils", diff --git a/ortools/sat/lb_tree_search.cc b/ortools/sat/lb_tree_search.cc index 96951583996..6ff46d0148e 100644 --- a/ortools/sat/lb_tree_search.cc +++ b/ortools/sat/lb_tree_search.cc @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "ortools/base/log_severity.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" @@ -54,7 +55,7 @@ namespace sat { LbTreeSearch::LbTreeSearch(Model* model) : name_(model->Name()), time_limit_(model->GetOrCreate()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), sat_solver_(model->GetOrCreate()), integer_encoder_(model->GetOrCreate()), trail_(model->GetOrCreate()), @@ -303,11 +304,11 @@ void LbTreeSearch::MarkBranchAsInfeasible(Node& node, bool true_branch) { if (true_branch) { node.UpdateTrueObjective(kMaxIntegerValue); MarkSubtreeAsDeleted(node.true_child); - node.true_child = NodeIndex(std::numeric_limits::max()); + node.true_child = NodeIndex(kint32max); } else { node.UpdateFalseObjective(kMaxIntegerValue); MarkSubtreeAsDeleted(node.false_child); - node.false_child = NodeIndex(std::numeric_limits::max()); + node.false_child = NodeIndex(kint32max); } } diff --git a/ortools/sat/lb_tree_search.h b/ortools/sat/lb_tree_search.h index 02e5536d987..c4bb8e928a8 100644 --- a/ortools/sat/lb_tree_search.h +++ b/ortools/sat/lb_tree_search.h @@ -25,6 +25,7 @@ #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/glop/variables_info.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" @@ -112,8 +113,8 @@ class LbTreeSearch { IntegerValue false_objective; // Points to adjacent nodes in the tree. Large if no connection. - NodeIndex true_child = NodeIndex(std::numeric_limits::max()); - NodeIndex false_child = NodeIndex(std::numeric_limits::max()); + NodeIndex true_child = NodeIndex(kint32max); + NodeIndex false_child = NodeIndex(kint32max); // Indicates if this nodes was removed from the tree. bool is_deleted = false; @@ -187,7 +188,7 @@ class LbTreeSearch { // Model singleton class used here. const std::string name_; TimeLimit* time_limit_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; SatSolver* sat_solver_; IntegerEncoder* integer_encoder_; Trail* trail_; diff --git a/ortools/sat/linear_constraint.cc b/ortools/sat/linear_constraint.cc index c3ad453470e..3b25c94585d 100644 --- a/ortools/sat/linear_constraint.cc +++ b/ortools/sat/linear_constraint.cc @@ -29,6 +29,7 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/sat_base.h" @@ -424,7 +425,7 @@ bool ValidateLinearConstraintForOverflow(const LinearConstraint& constraint, negative_sum = CapAdd(negative_sum, std::min(int64_t{0}, min_prod)); } - const int64_t limit = std::numeric_limits::max(); + const int64_t limit = kint64max; if (positive_sum >= limit) return false; if (negative_sum <= -limit) return false; if (CapSub(positive_sum, negative_sum) >= limit) return false; diff --git a/ortools/sat/linear_constraint_manager.cc b/ortools/sat/linear_constraint_manager.cc index 1941e8f65a4..9e0b6350215 100644 --- a/ortools/sat/linear_constraint_manager.cc +++ b/ortools/sat/linear_constraint_manager.cc @@ -36,6 +36,7 @@ #include "absl/types/span.h" #include "ortools/base/hash.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/glop/variables_info.h" #include "ortools/lp_data/lp_types.h" #include "ortools/sat/debug_solution.h" @@ -698,7 +699,7 @@ void LinearConstraintManager::FillDerivedFields(ConstraintInfo* info) { info->constraint.lb = std::max(min_sum, info->constraint.lb); info->constraint.ub = std::min(max_sum, info->constraint.ub); CHECK_NE(CapSub(info->constraint.ub.value(), info->constraint.lb.value()), - std::numeric_limits::max()); + kint64max); info->lb_is_trivial = min_sum >= info->constraint.lb; info->ub_is_trivial = max_sum <= info->constraint.ub; } diff --git a/ortools/sat/linear_constraint_test.cc b/ortools/sat/linear_constraint_test.cc index 7164bd84eb6..4b3d8fd0184 100644 --- a/ortools/sat/linear_constraint_test.cc +++ b/ortools/sat/linear_constraint_test.cc @@ -24,6 +24,7 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -77,9 +78,8 @@ TEST(ComputeInfinityNormTest, BasicTest) { } { LinearConstraintBuilder constraint; - constraint.AddTerm(y, IntegerValue(std::numeric_limits::max())); - EXPECT_EQ(IntegerValue(std::numeric_limits::max()), - ComputeInfinityNorm(constraint.Build())); + constraint.AddTerm(y, IntegerValue(kint64max)); + EXPECT_EQ(IntegerValue(kint64max), ComputeInfinityNorm(constraint.Build())); } } diff --git a/ortools/sat/linear_programming_constraint.cc b/ortools/sat/linear_programming_constraint.cc index 7f8247a688d..656a6377b8e 100644 --- a/ortools/sat/linear_programming_constraint.cc +++ b/ortools/sat/linear_programming_constraint.cc @@ -41,6 +41,7 @@ #include "ortools/algorithms/binary_search.h" #include "ortools/base/log_severity.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/glop/parameters.pb.h" #include "ortools/glop/revised_simplex.h" #include "ortools/glop/status.h" @@ -49,6 +50,7 @@ #include "ortools/lp_data/lp_types.h" #include "ortools/lp_data/scattered_vector.h" #include "ortools/lp_data/sparse.h" +#include "ortools/sat/cp_model_checker.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cuts.h" #include "ortools/sat/implied_bounds.h" @@ -97,9 +99,7 @@ void ScatteredIntegerVector::ClearAndResize(int size) { bool ScatteredIntegerVector::Add(glop::ColIndex col, IntegerValue value) { const int64_t add = CapAdd(value.value(), dense_vector_[col].value()); - if (add == std::numeric_limits::min() || - add == std::numeric_limits::max()) - return false; + if (add == kint64min || add == kint64max) return false; dense_vector_[col] = IntegerValue(add); if (is_sparse_ && is_zeros_[col]) { is_zeros_[col] = false; @@ -291,7 +291,8 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( objective_definition_(model->GetOrCreate()), shared_stats_(model->GetOrCreate()), shared_response_manager_(model->GetOrCreate()), - random_(model->GetOrCreate()), + cp_model_mapping_(model->GetOrCreate()), + random_(*model->GetOrCreate()), symmetrizer_(model->GetOrCreate()), linear_propagator_(model->GetOrCreate()), cover_cut_helper_(model), @@ -314,7 +315,7 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( } simplex_.SetParameters(simplex_params_); // Warning: SetRandom() must be called after SetParameters(). - simplex_.SetRandom(*random_); + simplex_.SetRandom(random_); if (parameters_.search_branching() == SatParameters::LP_SEARCH) { compute_reduced_cost_averages_ = true; } @@ -327,6 +328,8 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( // Initialize the IntegerVariable -> ColIndex mapping. CHECK(std::is_sorted(vars.begin(), vars.end())); + int num_proto_variables = 0; + // TODO(user): We shouldn't need to add variable from the orbit here in the // presence of symmetry. However they can still appear in cut, so it is a // bit tricky and require some refactoring to be tried. @@ -336,6 +339,11 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( implied_bounds_processor_.AddLpVariable(positive_variable); (*dispatcher_)[positive_variable] = this; + if (cp_model_mapping_->GetProtoVariableFromIntegerVariable( + positive_variable) != -1) { + ++num_proto_variables; + } + if (!symmetrizer_->AppearInFoldedProblem(positive_variable)) continue; integer_variables_.push_back(positive_variable); @@ -345,6 +353,11 @@ LinearProgrammingConstraint::LinearProgrammingConstraint( ++col; } + if (num_proto_variables == cp_model_mapping_->NumProtoVariables() && + parameters_.linearization_level() > 1) { + integer_solution_are_likely_feasible_ = true; + } + // Complete the extended variables with the orbit afterwards. if (symmetrizer_->HasSymmetry()) { for (const IntegerVariable var : integer_variables_) { @@ -578,17 +591,13 @@ bool LinearProgrammingConstraint::CreateLpFromConstraintManager() { // Since we used level zero bounds above, fix them. UpdateBoundsOfLpVariables(); - // Set the information for the step to polish the LP basis. All our variables - // are integer, but for now, we just try to minimize the fractionality of the - // binary variables. + // Set the integrality information for the step to polish the LP basis. if (parameters_.polish_lp_solution()) { + simplex_params_.set_dual_polish(true); + simplex_.ClearIntegralityScales(); const int num_vars = integer_variables_.size(); for (int i = 0; i < num_vars; ++i) { - const IntegerVariable cp_var = integer_variables_[i]; - const IntegerValue lb = integer_trail_->LevelZeroLowerBound(cp_var); - const IntegerValue ub = integer_trail_->LevelZeroUpperBound(cp_var); - if (lb != 0 || ub != 1) continue; simplex_.SetIntegralityScale( glop::ColIndex(i), 1.0 / scaler_.VariableScalingFactor(glop::ColIndex(i))); @@ -1090,6 +1099,42 @@ bool LinearProgrammingConstraint::SolveLp() { } } + // When we have a integer solution, it can be way faster to just check + // feasibility and report it rather than branching one variable at the + // time and propagating until we get to the solution. + // + // And also sometime, we don't even do the branching! like in lb_tree_search + // or even probing. + if (parameters_.exploit_integer_lp_solution() && + !parameters_.enumerate_all_solutions() && lp_solution_is_integer_ && + integer_solution_are_likely_feasible_ && + cp_model_mapping_->ModelProto() != nullptr) { + const int num_proto_vars = cp_model_mapping_->NumProtoVariables(); + absl::Span proto_vars = + cp_model_mapping_->GetVariableMapping(); + std::vector solution(num_proto_vars, 0); + for (int i = 0; i < num_proto_vars; ++i) { + const IntegerVariable var = proto_vars[i]; + CHECK_NE(var, kNoIntegerVariable); + solution[i] = std::round(expanded_lp_solution_[var]); + } + if (SolutionIsFeasible(*cp_model_mapping_->ModelProto(), solution)) { + // TODO(user): Shall we report all such solution? hopefully there + // shouldn't be too many. + // + // TODO(user): This should be the best reachable solution in that + // subtree, so we should likely backtrack right away. + shared_response_manager_->NewSolution( + solution, absl::StrCat(model_->Name(), " (lp)"), model_); + } else { + // We disable this if our integer solution do not seems to be feasible. + // TODO(user): find better heuristic? + if (++num_infeasible_integer_lp_solutions_ > 100) { + integer_solution_are_likely_feasible_ = false; + } + } + } + if (lp_solution_level_ == 0) { level_zero_lp_solution_ = lp_solution_; } @@ -1517,8 +1562,7 @@ bool LinearProgrammingConstraint::AddCutFromConstraints( bool LinearProgrammingConstraint::PostprocessAndAddCut( const std::string& name, const std::string& info, IntegerVariable first_slack, const CutData& cut) { - if (cut.rhs > absl::int128(std::numeric_limits::max()) || - cut.rhs < absl::int128(std::numeric_limits::min())) { + if (cut.rhs > absl::int128(kint64max) || cut.rhs < absl::int128(kint64min)) { VLOG(2) << "RHS overflow " << name << " " << info; ++num_cut_overflows_; return false; @@ -1868,7 +1912,7 @@ void LinearProgrammingConstraint::AddMirCuts() { // entries we process. We randomize the base_rows so that on the next calls // we do not do exactly the same if we can't process many base row. int64_t dtime_num_entries = 0; - std::shuffle(base_rows.begin(), base_rows.end(), *random_); + std::shuffle(base_rows.begin(), base_rows.end(), random_); std::vector weights; util_intops::StrongVector used_rows; @@ -1946,7 +1990,7 @@ void LinearProgrammingConstraint::AddMirCuts() { if (col_candidates.empty()) break; const ColIndex var_to_eliminate = - col_candidates[WeightedPick(weights, *random_)]; + col_candidates[WeightedPick(weights, random_)]; // What rows can we add to eliminate var_to_eliminate? std::vector possible_rows; @@ -1989,7 +2033,7 @@ void LinearProgrammingConstraint::AddMirCuts() { if (possible_rows.empty()) break; const RowIndex row_to_combine = - possible_rows[WeightedPick(weights, *random_)]; + possible_rows[WeightedPick(weights, random_)]; // Find the coefficient of the variable to eliminate. IntegerValue to_combine_coeff = 0; @@ -2027,8 +2071,7 @@ void LinearProgrammingConstraint::AddMirCuts() { } if (CapAdd(CapProd(max_magnitude.value(), std::abs(mult1.value())), CapProd(infinity_norms_[row_to_combine].value(), - std::abs(mult2.value()))) == - std::numeric_limits::max()) { + std::abs(mult2.value()))) == kint64max) { break; } @@ -2151,12 +2194,15 @@ bool LinearProgrammingConstraint::Propagate() { // if the deterministic time spent after the last level zero "solve" is lower // than the effort spent on that last solve. // + // We use a really high level to disable this in some TEST only. + // // TODO(user): also use the logic of lp_at_level_zero_is_final_. If we don't // have new info, there is no reason to rerun it. harder to make sure we // don't miss anything though. const double dtime_at_function_start = time_limit_->GetElapsedDeterministicTime(); - if (trail_->CurrentDecisionLevel() == 0 && old_num_force == 0) { + if (trail_->CurrentDecisionLevel() == 0 && old_num_force == 0 && + parameters_.linearization_level() < 10) { const double interval = dtime_at_function_start - last_root_level_deterministic_time_; if (last_root_level_deterministic_duration_ > interval) { @@ -2202,7 +2248,7 @@ bool LinearProgrammingConstraint::Propagate() { } simplex_.SetParameters(simplex_params_); - simplex_.SetRandom(*random_); + simplex_.SetRandom(random_); if (!SolveLp()) return true; if (!AnalyzeLp()) return false; @@ -2343,7 +2389,7 @@ bool LinearProgrammingConstraint::ScalingCanOverflow( const double magnitude = std::abs(std::round(double_coeff * factor_as_double)); if (std::isnan(magnitude)) return true; - if (magnitude >= static_cast(std::numeric_limits::max())) { + if (magnitude >= static_cast(kint64max)) { return true; } @@ -2389,7 +2435,7 @@ void LinearProgrammingConstraint::ScaleMultipliers( // TODO(user): we currently do not support scaling down, so we just abort // if with a scaling of 1, we reach the overflow_cap. - const int64_t overflow_cap = std::numeric_limits::max(); + const int64_t overflow_cap = kint64max; if (ScalingCanOverflow(/*power=*/0, take_objective_into_account, lp_multipliers, overflow_cap)) { ++num_scaling_issues_; @@ -2702,9 +2748,14 @@ bool LinearProgrammingConstraint::PropagateExactLpReason() { if (objective_cp_is_part_of_lp_) { // The objective is part of the lp. // This should only happen for objective with a single term. - CHECK_EQ(integer_objective_.size(), 1); - CHECK_EQ(integer_objective_[0].first, mirror_lp_variable_[objective_cp_]); - CHECK_EQ(integer_objective_[0].second, IntegerValue(1)); + // + // Tricky: if that variable is fixed, we might have removed it from the + // integer_objective_ completely as we filter it. + if (!integer_objective_.empty()) { + CHECK_EQ(integer_objective_.size(), 1); + CHECK_EQ(integer_objective_[0].first, mirror_lp_variable_[objective_cp_]); + CHECK_EQ(integer_objective_[0].second, IntegerValue(1)); + } take_objective_into_account = false; } diff --git a/ortools/sat/linear_programming_constraint.h b/ortools/sat/linear_programming_constraint.h index f6395a62441..9373a4c5cf6 100644 --- a/ortools/sat/linear_programming_constraint.h +++ b/ortools/sat/linear_programming_constraint.h @@ -157,6 +157,7 @@ class LinearProgrammingConstraint : public PropagatorInterface, // The main objective variable should be equal to the linear sum of // the arguments passed to SetObjectiveCoefficient(). void SetMainObjectiveVariable(IntegerVariable ivar) { + CHECK(!integer_objective_.empty()); objective_cp_ = ivar; objective_cp_is_part_of_lp_ = false; for (const IntegerVariable var : integer_variables_) { @@ -376,9 +377,6 @@ class LinearProgrammingConstraint : public PropagatorInterface, ScatteredIntegerVector* scattered_vector, IntegerValue* upper_bound) const; - // Shortcut for an integer linear expression type. - using LinearExpression = std::vector>; - // Converts a dense representation of a linear constraint to a sparse one // expressed in terms of IntegerVariable. void ConvertToLinearConstraint( @@ -470,7 +468,7 @@ class LinearProgrammingConstraint : public PropagatorInterface, std::vector tmp_coeffs_; std::vector tmp_vars_; - LinearExpression integer_objective_; + std::vector> integer_objective_; IntegerValue integer_objective_offset_ = IntegerValue(0); IntegerValue objective_infinity_norm_ = IntegerValue(0); util_intops::StrongVector @@ -538,7 +536,8 @@ class LinearProgrammingConstraint : public PropagatorInterface, ObjectiveDefinition* objective_definition_; SharedStatistics* shared_stats_; SharedResponseManager* shared_response_manager_; - ModelRandomGenerator* random_; + CpModelMapping* cp_model_mapping_; + absl::BitGenRef random_; LinearConstraintSymmetrizer* symmetrizer_; LinearPropagator* linear_propagator_; @@ -648,6 +647,11 @@ class LinearProgrammingConstraint : public PropagatorInterface, // We might temporarily disable the LP propagation. bool enabled_ = true; + // We set that to true if all proto variable are in the LP relaxation and + // we are at a high enough relaxation level. + bool integer_solution_are_likely_feasible_ = false; + int num_infeasible_integer_lp_solutions_ = 0; + // Logic to throttle level zero calls. int64_t num_root_level_skips_ = 0; int64_t num_root_level_solves_ = 0; diff --git a/ortools/sat/linear_programming_constraint_test.cc b/ortools/sat/linear_programming_constraint_test.cc index 2df8a638700..8bd177a0b6c 100644 --- a/ortools/sat/linear_programming_constraint_test.cc +++ b/ortools/sat/linear_programming_constraint_test.cc @@ -15,6 +15,7 @@ #include +#include #include #include #include @@ -101,21 +102,31 @@ LPProblem GeneratePermutationProblem(int size) { std::vector> node_constraint_indices; std::vector> node_constraint_coefs; + std::vector randomized_left; + std::vector randomized_right; + for (int i = 0; i < size; i++) { + randomized_left.push_back(i); + randomized_right.push_back(i); + } + std::shuffle(randomized_left.begin(), randomized_left.end(), absl::BitGen()); + std::shuffle(randomized_right.begin(), randomized_right.end(), + absl::BitGen()); + // Left and right nodes are indexed by [0, size). // The edge (left, right) has number left * size + right. const std::vector ones(size, IntegerValue(1)); - for (int left = 0; left < size; left++) { + for (const int left : randomized_left) { std::vector indices; - for (int right = 0; right < size; right++) { + for (int right = 0; right < size; ++right) { indices.push_back(left * size + right); } node_constraint_indices.push_back(indices); node_constraint_coefs.push_back(ones); } - for (int right = 0; right < size; right++) { + for (const int right : randomized_right) { std::vector indices; - for (int left = 0; left < size; left++) { + for (int left = 0; left < size; ++left) { indices.push_back(left * size + right); } node_constraint_indices.push_back(indices); @@ -134,6 +145,8 @@ LPProblem GeneratePermutationProblem(int size) { int CountSolutionsOfLPProblemUsingSAT(const LPProblem& problem) { Model model; model.GetOrCreate()->set_add_lp_constraints_lazily(false); + model.GetOrCreate()->set_use_sat_inprocessing(false); + model.GetOrCreate()->set_linearization_level(10); std::vector cp_variables; const int num_cp_vars = problem.num_integer_vars(); diff --git a/ortools/sat/linear_propagation.cc b/ortools/sat/linear_propagation.cc index 2c2e9e8b194..cfe9ad922b0 100644 --- a/ortools/sat/linear_propagation.cc +++ b/ortools/sat/linear_propagation.cc @@ -17,7 +17,6 @@ #include #include -#include #include #include #include @@ -33,12 +32,14 @@ #include "absl/types/span.h" #include "ortools/base/log_severity.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/integer.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" #include "ortools/sat/util.h" #include "ortools/util/bitset.h" @@ -50,6 +51,7 @@ namespace sat { LinearPropagator::LinearPropagator(Model* model) : trail_(model->GetOrCreate()), + sat_solver_(model->GetOrCreate()), integer_trail_(model->GetOrCreate()), enforcement_propagator_(model->GetOrCreate()), enforcement_helper_(model->GetOrCreate()), @@ -61,9 +63,12 @@ LinearPropagator::LinearPropagator(Model* model) precedences_(model->GetOrCreate()), lin2_indices_(model->GetOrCreate()), linear3_bounds_(model->GetOrCreate()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), shared_stats_(model->GetOrCreate()), watcher_id_(watcher_->Register(this)), + only_propagate_unit_linear_( + model->GetOrCreate()->search_branching() == + SatParameters::FIXED_SEARCH), order_(random_, time_limit_, [this](int id) { return GetVariables(infos_[id]); }) { // Note that we need this class always in sync. @@ -159,6 +164,29 @@ void LinearPropagator::OnVariableChange(IntegerVariable var, IntegerValue lb, SetPropagatedBy(var, id); order_.UpdateBound(var, lb); + + if (only_propagate_unit_linear_ && sat_solver_->num_failures() == 0) { + // Stop the propagation if `var` was propagated by a "non-unit" linear + // constraint (i.e., a linear constraint with more than one non-fixed + // variable). This avoids a quadratic number of propagations when fixing + // vars one by one when looking for a first solution with FIXED_SEARCH. + if (id != -1 && infos_[id].rev_size != 1) { + return; + } + // When we only propagate unit linear constraints, we need to activate again + // all constraints that contain a fixed variable, and this is both + // directions. + if (integer_trail_->UpperBound(var) == lb) { + AddVarConstraintsToQueue(NegationOf(var)); + } + } + + AddVarConstraintsToQueue(var); +} + +void LinearPropagator::AddVarConstraintsToQueue(IntegerVariable var) { + const int size = var_to_constraint_ids_[var].size(); + if (size == 0) return; Bitset64::View in_queue = in_queue_.view(); time_limit_->AdvanceDeterministicTime(static_cast(size) * 1e-9); for (const int id : var_to_constraint_ids_[var]) { @@ -312,6 +340,17 @@ bool LinearPropagator::Propagate() { return true; } +bool LinearPropagator::PropagateAll() { + if (!only_propagate_unit_linear_) { + return true; + } + only_propagate_unit_linear_ = false; + for (int i = 0; i < infos_.size(); ++i) { + AddToQueueIfNeeded(i); + } + return Propagate(); +} + // Adds a new constraint to the propagator. bool LinearPropagator::AddConstraint( absl::Span enforcement_literals, @@ -824,7 +863,7 @@ bool LinearPropagator::ReportConflictingCycle() { }); // Relax the linear reason if everything fit on an int64_t. - const absl::int128 limit{std::numeric_limits::max()}; + const absl::int128 limit{kint64max}; const absl::int128 slack = implied_lb - rhs_sum; if (slack > 1) { reason_coeffs_.clear(); diff --git a/ortools/sat/linear_propagation.h b/ortools/sat/linear_propagation.h index 20ea6aeea49..d3fb05400f4 100644 --- a/ortools/sat/linear_propagation.h +++ b/ortools/sat/linear_propagation.h @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -26,8 +25,10 @@ #include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/enforcement_helper.h" #include "ortools/sat/integer.h" @@ -35,8 +36,8 @@ #include "ortools/sat/model.h" #include "ortools/sat/precedences.h" #include "ortools/sat/sat_base.h" +#include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" -#include "ortools/sat/util.h" #include "ortools/util/bitset.h" #include "ortools/util/rev.h" #include "ortools/util/strong_integers.h" @@ -58,7 +59,7 @@ namespace sat { class ConstraintPropagationOrder { public: ConstraintPropagationOrder( - ModelRandomGenerator* random, TimeLimit* time_limit, + absl::BitGenRef random, TimeLimit* time_limit, std::function(int)> id_to_vars) : random_(random), time_limit_(time_limit), @@ -116,7 +117,7 @@ class ConstraintPropagationOrder { int best_id = 0; int best_num_vars = 0; - int best_degree = std::numeric_limits::max(); + int best_degree = kint32max; int64_t work_done = 0; const int size = ids_.size(); const auto var_has_entry = var_has_entry_.const_view(); @@ -205,7 +206,7 @@ class ConstraintPropagationOrder { } public: - ModelRandomGenerator* random_; + absl::BitGenRef random_; TimeLimit* time_limit_; std::function(int)> id_to_vars_func_; @@ -239,6 +240,11 @@ class LinearPropagator : public PropagatorInterface, bool Propagate() final; void SetLevel(int level) final; + // In FIXED_SEARCH and until the first backtrack, Propagate() does not fully + // propagate in order to be faster. In this case, completes the propagation + // and disables the "fast propagation" mode. Otherwise does nothing. + bool PropagateAll(); + std::string LazyReasonName() const override { return "LinearPropagator"; } // Adds a new constraint to the propagator. @@ -286,6 +292,7 @@ class LinearPropagator : public PropagatorInterface, // Called when the lower bound of a variable changed. The id is the constraint // id that caused this change or -1 if it comes from an external source. void OnVariableChange(IntegerVariable var, IntegerValue lb, int id); + void AddVarConstraintsToQueue(IntegerVariable var); // Returns false on conflict. ABSL_MUST_USE_RESULT bool PropagateOneConstraint(int id); @@ -317,6 +324,7 @@ class LinearPropagator : public PropagatorInterface, // External class needed. Trail* trail_; + SatSolver* sat_solver_; IntegerTrail* integer_trail_; EnforcementPropagator* enforcement_propagator_; EnforcementHelper* enforcement_helper_; @@ -327,7 +335,8 @@ class LinearPropagator : public PropagatorInterface, EnforcedLinear2Bounds* precedences_; Linear2Indices* lin2_indices_; Linear2BoundsFromLinear3* linear3_bounds_; - ModelRandomGenerator* random_; + + absl::BitGenRef random_; SharedStatistics* shared_stats_ = nullptr; const int watcher_id_; @@ -363,6 +372,10 @@ class LinearPropagator : public PropagatorInterface, Bitset64 in_queue_; std::deque propagation_queue_; + // Whether to only propagate linear constraints with exactly one non-fixed + // variable. This only applies before the first conflict is detected. + bool only_propagate_unit_linear_; + // Lin3 constraint that need to be processed to push lin2 bounds. SparseBitset lin3_ids_; diff --git a/ortools/sat/linear_relaxation.cc b/ortools/sat/linear_relaxation.cc index 1fdf3a78519..98caec8125e 100644 --- a/ortools/sat/linear_relaxation.cc +++ b/ortools/sat/linear_relaxation.cc @@ -35,6 +35,7 @@ #include "ortools/base/mathutil.h" #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/circuit.h" // for ReindexArcs. #include "ortools/sat/clause.h" #include "ortools/sat/cp_model.pb.h" @@ -358,8 +359,7 @@ void LinearizeComplexLinear1(Model* m, const CpModelProto& model_proto, if (mapping == nullptr || integer_trail == nullptr) return; // Lets regroup all interval linear one by variable. - std::vector flat_vars; - std::vector flat_lin1s; + CompactVectorVectorBuilder var_to_lin1_builder; const int num_constraints = model_proto.constraints().size(); for (int c = 0; c < num_constraints; ++c) { const ConstraintProto& ct = model_proto.constraints(c); @@ -372,13 +372,12 @@ void LinearizeComplexLinear1(Model* m, const CpModelProto& model_proto, CHECK(RefIsPositive(proto_var)); if (!mapping->IsInteger(proto_var)) continue; - flat_vars.push_back(mapping->Integer(proto_var)); - flat_lin1s.push_back(c); + var_to_lin1_builder.Add(mapping->Integer(proto_var), c); } int num_added_constraints = 0; - CompactVectorVector var_to_lin1; - var_to_lin1.ResetFromFlatMapping(flat_vars, flat_lin1s); + const CompactVectorVector var_to_lin1( + var_to_lin1_builder); for (IntegerVariable var(0); var < var_to_lin1.size(); ++var) { const Domain var_domain = integer_trail->InitialVariableDomain(var); if (var_to_lin1[var].size() < 2) continue; @@ -1459,9 +1458,7 @@ void AppendLinearConstraintRelaxation(const ConstraintProto& ct, Model* model, const IntegerValue rhs_domain_min = IntegerValue(ct.linear().domain(0)); const IntegerValue rhs_domain_max = IntegerValue(ct.linear().domain(ct.linear().domain_size() - 1)); - if (rhs_domain_min == std::numeric_limits::min() && - rhs_domain_max == std::numeric_limits::max()) - return; + if (rhs_domain_min == kint64min && rhs_domain_max == kint64max) return; LinearConstraintBuilder lc(model, rhs_domain_min, rhs_domain_max); auto* mapping = model->Get(); @@ -2214,9 +2211,13 @@ LinearRelaxation ComputeLinearRelaxation(const CpModelProto& model_proto, // Linearize the at most one constraints. Note that we transform them // into maximum "at most one" first and we removes redundant ones. - m->GetOrCreate()->TransformIntoMaxCliques( - &relaxation.at_most_ones, - SafeDoubleToInt64(params.merge_at_most_one_work_limit())); + if (!m->GetOrCreate()->TransformIntoMaxCliques( + &relaxation.at_most_ones, + SafeDoubleToInt64(params.merge_at_most_one_work_limit()))) { + m->GetOrCreate()->NotifyThatModelIsUnsat(); + return relaxation; + } + for (const std::vector& at_most_one : relaxation.at_most_ones) { if (at_most_one.empty()) continue; diff --git a/ortools/sat/lp_utils.cc b/ortools/sat/lp_utils.cc index 08001bdbc7f..f66114fd2ff 100644 --- a/ortools/sat/lp_utils.cc +++ b/ortools/sat/lp_utils.cc @@ -29,6 +29,7 @@ #include "absl/types/span.h" #include "ortools/base/log_severity.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/glop/lp_solver.h" #include "ortools/glop/parameters.pb.h" #include "ortools/linear_solver/linear_solver.pb.h" @@ -259,7 +260,7 @@ int64_t FindRationalFactor(double x, int64_t limit, double tolerance) { } x = 1 / x; const double floored_x = std::floor(x); - if (floored_x >= static_cast(std::numeric_limits::max())) { + if (floored_x >= static_cast(kint64max)) { return 0; } const int64_t new_q = @@ -1206,12 +1207,11 @@ absl::Status ConstraintScaler::ScaleAndAddConstraint( // on an int64_t, if the scaled bounds are too large, the constraint is either // always true or always false. const Fractional scaled_lb = std::ceil(lb * scaling_factor); - if (lb == kInfinity || scaled_lb >= std::numeric_limits::max()) { + if (lb == kInfinity || scaled_lb >= kint64max) { // Corner case: infeasible model. - arg->add_domain(std::numeric_limits::max()); - } else if (lb == -kInfinity || - scaled_lb <= std::numeric_limits::min()) { - arg->add_domain(std::numeric_limits::min()); + arg->add_domain(kint64max); + } else if (lb == -kInfinity || scaled_lb <= kint64min) { + arg->add_domain(kint64min); } else { arg->add_domain(CeilRatio(IntegerValue(static_cast(scaled_lb)), IntegerValue(gcd)) @@ -1219,12 +1219,11 @@ absl::Status ConstraintScaler::ScaleAndAddConstraint( } const Fractional scaled_ub = std::floor(ub * scaling_factor); - if (ub == -kInfinity || scaled_ub <= std::numeric_limits::min()) { + if (ub == -kInfinity || scaled_ub <= kint64min) { // Corner case: infeasible model. - arg->add_domain(std::numeric_limits::min()); - } else if (ub == kInfinity || - scaled_ub >= std::numeric_limits::max()) { - arg->add_domain(std::numeric_limits::max()); + arg->add_domain(kint64min); + } else if (ub == kInfinity || scaled_ub >= kint64max) { + arg->add_domain(kint64max); } else { arg->add_domain(FloorRatio(IntegerValue(static_cast(scaled_ub)), IntegerValue(gcd)) diff --git a/ortools/sat/lrat_checker.cc b/ortools/sat/lrat_checker.cc index 4e8fcf6a1ac..3bde65f23bb 100644 --- a/ortools/sat/lrat_checker.cc +++ b/ortools/sat/lrat_checker.cc @@ -312,15 +312,27 @@ bool LratChecker::DebugCheckProofClauseId(ClausePtr clause, absl::StrCat("proof clause not found: ", proof_clause, " ", absl::StrJoin(proof_clause.GetLiterals(), ","))); } - absl::btree_set expected_literals; - for (const Literal literal : it->second) { - expected_literals.insert(literal); - } - absl::btree_set actual_literals; - for (const Literal literal : proof_clause.GetLiterals()) { - actual_literals.insert(literal); + bool difference_found = + it->second.size() != proof_clause.GetLiterals().size(); + debug_scratch_literal_set_.Resize(LiteralIndex(2 * num_variables_)); + if (!difference_found) { + for (const Literal literal : it->second) { + debug_scratch_literal_set_.Set(literal); + } + for (const Literal literal : proof_clause.GetLiterals()) { + if (!debug_scratch_literal_set_[literal]) { + difference_found = true; + break; + } + debug_scratch_literal_set_.Clear(literal); + } } - if (actual_literals != expected_literals) { + if (difference_found) { + debug_scratch_literal_set_.ClearAndResize(LiteralIndex(2 * num_variables_)); + const absl::btree_set actual_literals( + proof_clause.GetLiterals().begin(), proof_clause.GetLiterals().end()); + const absl::btree_set expected_literals(it->second.begin(), + it->second.end()); return Error( clause, absl::StrCat("proof clause ", proof_clause, ": unexpected literals ", diff --git a/ortools/sat/lrat_checker.h b/ortools/sat/lrat_checker.h index 11061563d8b..40e4558a6d8 100644 --- a/ortools/sat/lrat_checker.h +++ b/ortools/sat/lrat_checker.h @@ -204,6 +204,7 @@ class LratChecker { // Only used if DEBUG_MODE is true. absl::flat_hash_map> debug_clause_by_ptr_; + SparseBitset debug_scratch_literal_set_; SharedStatistics* stats_; }; diff --git a/ortools/sat/lrat_checker_test.cc b/ortools/sat/lrat_checker_test.cc index d2764e377d8..39bbc8fa4f9 100644 --- a/ortools/sat/lrat_checker_test.cc +++ b/ortools/sat/lrat_checker_test.cc @@ -226,6 +226,33 @@ TEST(LratCheckerTest, CheckSuccessWithRatClausesExtensions) { LratChecker::RatClauses{c7, {c6}}})); } +TEST(LratCheckerTest, VariableAdditionWithRatProof) { + Model model; + ClauseFactory factory; + + LratChecker& checker = *model.GetOrCreate(); + checker.EnableRatProofs(); + + // Add a new variable x4 <=> x1 ^ x2 ^ not(x3), assuming that x1, x2, x3 + // already exist. + // First add the clauses not(x4) v x1, not(x4) v x2, and not(x4) v not(x3). + // They can be proved with an empty RAT proof since there are no clauses + // containing x4, the negation of the pivot not(x4). + const ClausePtr c1 = factory.NewClause({-4, +1}); + const ClausePtr c2 = factory.NewClause({-4, +2}); + const ClausePtr c3 = factory.NewClause({-4, -3}); + EXPECT_TRUE(checker.AddInferredClause(c1, {}, {})); + EXPECT_TRUE(checker.AddInferredClause(c2, {}, {})); + EXPECT_TRUE(checker.AddInferredClause(c3, {}, {})); + // Now add the clause x1 ^ x2 ^ not(x3) => x4. This requires a RAT proof + // listing all the clauses C containing not(x4), the negation of the pivot x4. + // Here each resolvant C has two pairs of complementary literals with c4, + // hence an empty rup_clauses is sufficient in each RatClauses. + const ClausePtr c4 = factory.NewClause({+4, -1, -2, +3}); + EXPECT_TRUE(checker.AddInferredClause( + c4, {}, {{.resolvant = c1}, {.resolvant = c2}, {.resolvant = c3}})); +} + TEST(LratCheckerTest, ErrorStateIsSticky) { Model model; ClauseFactory factory; diff --git a/ortools/sat/model.h b/ortools/sat/model.h index 8b2866fdf9c..28110a82af1 100644 --- a/ortools/sat/model.h +++ b/ortools/sat/model.h @@ -196,6 +196,12 @@ class Model { singletons_[type_id] = non_owned_class; } + template + void Unregister() { + const size_t type_id = FastTypeId(); + singletons_.erase(type_id); + } + const std::string& Name() const { return name_; } private: diff --git a/ortools/sat/no_overlap_2d_helper.cc b/ortools/sat/no_overlap_2d_helper.cc index 583e6331d29..68ffabc1232 100644 --- a/ortools/sat/no_overlap_2d_helper.cc +++ b/ortools/sat/no_overlap_2d_helper.cc @@ -206,6 +206,7 @@ void NoOverlap2DConstraintHelper::Reset( std::vector y_sizes; std::vector y_reason_for_presence; + std::vector old_to_new_box_index(NumBoxes(), -1); auto add_non_fixed_box = [&](int box_index) { x_starts.push_back(x_helper_->Starts()[box_index]); x_ends.push_back(x_helper_->Ends()[box_index]); @@ -225,6 +226,8 @@ void NoOverlap2DConstraintHelper::Reset( y_reason_for_presence.push_back(kNoLiteralIndex); } + DCHECK_EQ(old_to_new_box_index[box_index], -1); + old_to_new_box_index[box_index] = x_starts.size() - 1; return x_starts.size() - 1; }; @@ -248,7 +251,7 @@ void NoOverlap2DConstraintHelper::Reset( active_bounding_boxes.reserve(new_num_boxes); active_box_indexes.reserve(new_num_boxes); DCHECK_EQ(x_helper_->CurrentDecisionLevel(), 0); - for (int box : non_fixed_box_indexes) { + for (const int box : non_fixed_box_indexes) { if (IsAbsent(box)) continue; active_bounding_boxes.push_back(GetBoundingRectangle(box)); // At level zero we can do a stronger check whether a box is fixed, since @@ -305,6 +308,16 @@ void NoOverlap2DConstraintHelper::Reset( y_helper_->SetEnforcementId(enforcement_id_); x_demands_helper_ = nullptr; y_demands_helper_ = nullptr; + + // Remap the watch indices. + for (int& ref : watch_index_to_box_) { + if (ref == -1) continue; + if (old_to_new_box_index[ref] == -1) { + ref = -1; + } else { + ref = old_to_new_box_index[ref]; + } + } } bool NoOverlap2DConstraintHelper::IsEnforced() const { @@ -312,11 +325,34 @@ bool NoOverlap2DConstraintHelper::IsEnforced() const { EnforcementStatus::IS_ENFORCED; } +bool NoOverlap2DConstraintHelper::IncrementalPropagate( + const std::vector& watch_indices) { + if (!IsEnforced()) return true; + for (const int id : propagators_watching_) { + watcher_->CallOnNextPropagate(id); + } + + if (x_helper_->CurrentDecisionLevel() == 0) { + // We will recompute the cache in any case, so we can return here. + return Propagate(); + } + + for (const int i : watch_indices) { + const int local_index = watch_index_to_box_[i]; + if (local_index != -1) { + x_helper_->RecomputeCache(local_index); + y_helper_->RecomputeCache(local_index); + } + } + return true; +} + bool NoOverlap2DConstraintHelper::Propagate() { if (!IsEnforced()) return true; for (const int id : propagators_watching_) { watcher_->CallOnNextPropagate(id); } + if (!x_helper_->Propagate() || !y_helper_->Propagate()) return false; if (x_helper_->CurrentDecisionLevel() == 0) { @@ -376,6 +412,7 @@ void NoOverlap2DConstraintHelper::RegisterWith( absl::Span enforcement_literals) { const int id = watcher->Register(this); const int num_boxes = NumBoxes(); + watch_index_to_box_.assign(num_boxes, -1); for (int b = 0; b < num_boxes; ++b) { if (x_helper_->IsOptional(b)) { watcher->WatchLiteral(x_helper_->PresenceLiteral(b), id); @@ -383,14 +420,25 @@ void NoOverlap2DConstraintHelper::RegisterWith( if (y_helper_->IsOptional(b)) { watcher->WatchLiteral(y_helper_->PresenceLiteral(b), id); } - watcher->WatchIntegerVariable(x_helper_->Sizes()[b].var, id); - watcher->WatchIntegerVariable(x_helper_->Starts()[b].var, id); - watcher->WatchIntegerVariable(x_helper_->Ends()[b].var, id); - watcher->WatchIntegerVariable(y_helper_->Sizes()[b].var, id); - watcher->WatchIntegerVariable(y_helper_->Starts()[b].var, id); - watcher->WatchIntegerVariable(y_helper_->Ends()[b].var, id); + + // Initially there is no remapping. + watch_index_to_box_[b] = b; + + // It is important to only update the cache of start/end/etc... + // incrementally otherwise we will be in O(num_boxes) each time we push a + // bound. + // + // TODO(user): split x/y watching, but then we need to be careful when + // we swap x/y ... + watcher->WatchIntegerVariable(x_helper_->Sizes()[b].var, id, b); + watcher->WatchIntegerVariable(x_helper_->Starts()[b].var, id, b); + watcher->WatchIntegerVariable(x_helper_->Ends()[b].var, id, b); + watcher->WatchIntegerVariable(y_helper_->Sizes()[b].var, id, b); + watcher->WatchIntegerVariable(y_helper_->Starts()[b].var, id, b); + watcher->WatchIntegerVariable(y_helper_->Ends()[b].var, id, b); } watcher->SetPropagatorPriority(id, 0); + enforcement_id_ = enforcement_helper_.Register(enforcement_literals, watcher, id); x_helper_->SetEnforcementId(enforcement_id_); diff --git a/ortools/sat/no_overlap_2d_helper.h b/ortools/sat/no_overlap_2d_helper.h index b9b893fbd5b..143f2270e96 100644 --- a/ortools/sat/no_overlap_2d_helper.h +++ b/ortools/sat/no_overlap_2d_helper.h @@ -224,6 +224,7 @@ class NoOverlap2DConstraintHelper : public PropagatorInterface { int NumBoxes() const { return x_helper_->NumTasks(); } + bool IncrementalPropagate(const std::vector& watch_indices) override; bool Propagate() override; // Note that the helpers are only valid until the next call to @@ -249,6 +250,14 @@ class NoOverlap2DConstraintHelper : public PropagatorInterface { CompactVectorVector connected_components_; + // The watcher indices never changes, but box indices are remapped by Reset(). + // Note that as boxes become fixed at the root level, the corresponding entry + // will be set to -1 here. + // + // TODO(user): We should probably remove the entry in the watch list of the + // GenericLiteralWatcher too. + std::vector watch_index_to_box_; + bool axes_are_swapped_; std::unique_ptr x_helper_; std::unique_ptr y_helper_; diff --git a/ortools/sat/opb_reader.h b/ortools/sat/opb_reader.h index f38431e3521..88e3b47f236 100644 --- a/ortools/sat/opb_reader.h +++ b/ortools/sat/opb_reader.h @@ -32,6 +32,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/filelineiter.h" @@ -119,8 +120,8 @@ class OpbReader { struct PbConstraint { std::vector terms; PbConstraintType type = UNDEFINED_OPERATION; - int64_t rhs = std::numeric_limits::min(); - int64_t soft_cost = std::numeric_limits::max(); + int64_t rhs = kint64min; + int64_t soft_cost = kint64max; }; // Since the problem name is not stored in the opb format, we infer it from @@ -227,7 +228,7 @@ class OpbReader { std::string ValidateModel() { // Normalize and validate constraints. for (const PbConstraint& constraint : constraints_) { - if (constraint.rhs == std::numeric_limits::min()) { + if (constraint.rhs == kint64min) { return "constraint error: undefined rhs"; } @@ -337,7 +338,7 @@ class OpbReader { } if (constraint.type == GE_OPERATION) { lin->add_domain(constraint.rhs); - lin->add_domain(std::numeric_limits::max()); + lin->add_domain(kint64max); } else if (constraint.type == EQ_OPERATION) { lin->add_domain(constraint.rhs); lin->add_domain(constraint.rhs); @@ -345,7 +346,7 @@ class OpbReader { LOG(FATAL) << "Unsupported operation: " << constraint.type; } - if (constraint.soft_cost != std::numeric_limits::max()) { + if (constraint.soft_cost != kint64max) { const int violation_var_index = model->variables_size(); IntegerVariableProto* violation_var = model->add_variables(); violation_var->add_domain(0); @@ -370,7 +371,7 @@ class OpbReader { if (top_cost_.has_value()) { CpObjectiveProto* obj = model->mutable_objective(); - obj->add_domain(std::numeric_limits::min()); + obj->add_domain(kint64min); obj->add_domain(top_cost_.value()); } } diff --git a/ortools/sat/parameters_validation.cc b/ortools/sat/parameters_validation.cc index 566a8b5fc8f..5350a3aeeac 100644 --- a/ortools/sat/parameters_validation.cc +++ b/ortools/sat/parameters_validation.cc @@ -16,10 +16,10 @@ #include #include -#include #include #include "absl/strings/str_cat.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model_search.h" #include "ortools/sat/sat_parameters.pb.h" @@ -122,7 +122,7 @@ std::string ValidateParameters(const SatParameters& params) { // validation. It is however not open sourced. TEST_IN_RANGE(mip_max_activity_exponent, 1, 62); TEST_IN_RANGE(mip_max_bound, 0, 1e17); - TEST_IN_RANGE(solution_pool_size, 1, std::numeric_limits::max()); + TEST_IN_RANGE(solution_pool_size, 1, kint32max); // Feasibility jump. TEST_NOT_NAN(feasibility_jump_decay); diff --git a/ortools/sat/precedences.cc b/ortools/sat/precedences.cc index ac5f61f0884..ce77cc802d8 100644 --- a/ortools/sat/precedences.cc +++ b/ortools/sat/precedences.cc @@ -247,6 +247,7 @@ IntegerValue RootLevelLinear2Bounds::GetUpperBoundNoTrail( std::vector> RootLevelLinear2Bounds::GetSortedNonTrivialUpperBounds() const { std::vector> result; + result.reserve(best_upper_bounds_.size()); for (LinearExpression2Index index = LinearExpression2Index{0}; index < best_upper_bounds_.size(); ++index) { const IntegerValue ub = best_upper_bounds_[index]; @@ -802,14 +803,14 @@ void ConditionalLinear2Bounds::AddPartialRelation(Literal lit, void ConditionalLinear2Bounds::Build() { DCHECK(!is_built_); is_built_ = true; - std::vector> literal_key_values; + CompactVectorVectorBuilder lit_to_relations_builder; + lit_to_relations_builder.ReserveNumItems(num_enforced_relations_); const int num_relations = relations_.size(); - literal_key_values.reserve(num_enforced_relations_); for (int i = 0; i < num_relations; ++i) { const Relation& r = relations_[i]; - literal_key_values.emplace_back(r.enforcement.Index(), i); + lit_to_relations_builder.Add(r.enforcement.Index(), i); } - lit_to_relations_.ResetFromPairs(literal_key_values); + lit_to_relations_.ResetFromBuilder(lit_to_relations_builder); lit_to_relations_.Add({}); // One extra unit size to make sure the negation // cannot be out of bounds in lit_to_relations_. @@ -1205,19 +1206,18 @@ int GreaterThanAtLeastOneOfDetector::AddGreaterThanAtLeastOneOfConstraints( CompactVectorVector implied_bounds_by_literal; { + CompactVectorVectorBuilder + implied_bounds_by_literal_builder; const auto& all_implied_bounds = implied_bounds_.GetModelImpliedBounds(); - std::vector implied_bounds_conditions; - std::vector implied_bounds_integer_lit; - implied_bounds_conditions.reserve(all_implied_bounds.size()); - implied_bounds_integer_lit.reserve(all_implied_bounds.size()); + implied_bounds_by_literal_builder.ReserveNumItems( + all_implied_bounds.size()); for (const auto& [literal_var_pair, bound] : all_implied_bounds) { - implied_bounds_conditions.push_back(literal_var_pair.first); - implied_bounds_integer_lit.push_back( + implied_bounds_by_literal_builder.Add( + literal_var_pair.first, IntegerLiteral::GreaterOrEqual(literal_var_pair.second, bound)); } - implied_bounds_by_literal.ResetFromFlatMapping( - std::move(implied_bounds_conditions), - std::move(implied_bounds_integer_lit), 2 * solver->NumVariables()); + implied_bounds_by_literal.ResetFromBuilder( + implied_bounds_by_literal_builder, 2 * solver->NumVariables()); } // We have two possible approaches. For now, we prefer the first one except if diff --git a/ortools/sat/precedences.h b/ortools/sat/precedences.h index e014364c8db..e336196c50e 100644 --- a/ortools/sat/precedences.h +++ b/ortools/sat/precedences.h @@ -114,6 +114,17 @@ class Linear2Indices { // expression must already be canonicalized and divided by its GCD. LinearExpression2Index GetIndex(LinearExpression2 expr) const; + // Returns the number of positive linear2 expressions that have a potentially + // non-trivial bound. + // These expressions can be obtained with GetExpression(2 * i), + // for all i in [0,NumStoredPositiveLinear2()[. + LinearExpression2Index NumStoredPositiveLinear2() const { + return LinearExpression2Index(exprs_.size()); + } + + // Returns the linear expression corresponding to the given index. The index + // must be less than 2 * NumStoredPositiveLinear2(). Even indices correspond + // to positive linear expressions, and odd indices to their negation. LinearExpression2 GetExpression(LinearExpression2Index index) const; // Return all positive linear2 expressions that have a potentially non-trivial diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 7b35d9a1825..784e3b8318e 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -16,7 +16,7 @@ #include #include #include -#include +#include #include #include #include @@ -37,6 +37,7 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "ortools/base/log_severity.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_checker.h" @@ -70,13 +71,178 @@ int SavedLiteral::Get(PresolveContext* context) const { int SavedVariable::Get() const { return ref_; } +void LazyConstraintVariableGraph::UpdateLinear1Usage(const ConstraintProto& ct, + int c) { + const int old_var = constraint_to_linear1_var_[c]; + if (old_var >= 0) { + var_to_num_linear1_[old_var]--; + DCHECK_GE(var_to_num_linear1_[old_var], 0); + } + if (ct.constraint_case() == ConstraintProto::ConstraintCase::kLinear && + ct.linear().vars().size() == 1) { + const int var = PositiveRef(ct.linear().vars(0)); + constraint_to_linear1_var_[c] = var; + var_to_num_linear1_[var]++; + } else { + constraint_to_linear1_var_[c] = -1; + } +} + +void LazyConstraintVariableGraph::MaybeResizeIntervalData() { + // Lazy allocation so that we only do that if there are some interval. + const int num_constraints = constraint_to_vars_.size(); + if (constraint_to_intervals_.size() != num_constraints) { + constraint_to_intervals_.resize(num_constraints); + interval_usage_.resize(num_constraints); + } +} + +void LazyConstraintVariableGraph::AddVariableUsage(int c) { + const ConstraintProto& ct = cp_model_.constraints(c); + + constraint_to_vars_[c] = UsedVariables(ct); + for (const int v : constraint_to_vars_[c]) { + DCHECK_LT(v, var_to_constraints_.size()); + DCHECK(!var_was_removed_[v]); + var_to_constraints_[v].insert(c); + } + + std::vector used_interval = UsedIntervals(ct); + if (!used_interval.empty()) { + MaybeResizeIntervalData(); + constraint_to_intervals_[c].swap(used_interval); + for (const int i : constraint_to_intervals_[c]) interval_usage_[i]++; + } + + UpdateLinear1Usage(ct, c); + +#ifdef CHECK_HINT + // Crash if the loaded hint is infeasible for this constraint. + // This is helpful to debug a wrong presolve that kill a feasible solution. + if (cp_model_.has_solution_hint() && solution_crush_.SolutionIsLoaded() && + !ConstraintIsFeasible(cp_model_, ct, solution_crush_.GetVarValues())) { + LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " + << ct.ShortDebugString(); + } +#endif +} + +void LazyConstraintVariableGraph::EraseFromVarToConstraint(int var, int c) { + var_to_constraints_[var].erase(c); + if (var_to_constraints_[var].size() <= 3) { + var_with_reduced_small_degree_.Set(var); + } +} + +void LazyConstraintVariableGraph::UpdateConstraintVariableUsage(int c) { + DCHECK_EQ(constraint_to_vars_.size(), cp_model_.constraints_size()); + const ConstraintProto& ct = cp_model_.constraints(c); + + // We don't optimize the interval usage as this is not super frequent. + std::vector used_interval = UsedIntervals(ct); + if (c < constraint_to_intervals_.size() || !used_interval.empty()) { + MaybeResizeIntervalData(); + for (const int i : constraint_to_intervals_[c]) interval_usage_[i]--; + constraint_to_intervals_[c].swap(used_interval); + for (const int i : constraint_to_intervals_[c]) interval_usage_[i]++; + } + + // For the variables, we avoid an erase() followed by an insert() for the + // variables that didn't change. + std::vector new_usage = UsedVariables(ct); + const absl::Span old_usage = constraint_to_vars_[c]; + const int old_size = old_usage.size(); + int i = 0; + for (const int var : new_usage) { + DCHECK(!VariableWasRemoved(var)); + while (i < old_size && old_usage[i] < var) { + EraseFromVarToConstraint(old_usage[i], c); + ++i; + } + if (i < old_size && old_usage[i] == var) { + ++i; + } else { + var_to_constraints_[var].insert(c); + } + } + for (; i < old_size; ++i) { + EraseFromVarToConstraint(old_usage[i], c); + } + constraint_to_vars_[c].swap(new_usage); + + UpdateLinear1Usage(ct, c); + +#ifdef CHECK_HINT + // Crash if the loaded hint is infeasible for this constraint. + // This is helpful to debug a wrong presolve that kill a feasible solution. + if (cp_model_.has_solution_hint() && solution_crush_.SolutionIsLoaded() && + !ConstraintIsFeasible(cp_model_, ct, solution_crush_.GetVarValues())) { + LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " + << ct.ShortDebugString(); + } +#endif +} + +// TODO(user): Also test var_to_constraints_ !! +bool PresolveContext::ConstraintVariableUsageIsConsistent() const { + // We do not care in these cases. + if (is_unsat_) return true; + if (time_limit_->LimitReached()) return true; + + const int num_vars = working_model_->variables_size(); + std::vector linear1_count(num_vars, 0); + for (int c = 0; c < working_model_->constraints_size(); ++c) { + const ConstraintProto& ct = working_model_->constraints(c); + if (graph_.ConstraintToVars(c) != UsedVariables(ct)) { + LOG(INFO) << "Wrong variables usage for constraint: \n" + << ProtobufDebugString(ct) + << " old_size: " << graph_.ConstraintToVars(c).size(); + return false; + } + if (ct.constraint_case() == ConstraintProto::kLinear && + ct.linear().vars().size() == 1) { + linear1_count[PositiveRef(ct.linear().vars(0))]++; + if (graph_.ConstraintToLinear1Var(c) != + PositiveRef(ct.linear().vars(0))) { + LOG(INFO) << "Wrong variables for linear1: \n" + << ProtobufDebugString(ct) + << " saved_var: " << graph_.ConstraintToLinear1Var(c); + return false; + } + } + } + int num_in_objective = 0; + for (int v = 0; v < num_vars; ++v) { + if (linear1_count[v] != graph_.VarToNumLinear1(v)) { + LOG(INFO) << "Variable " << v << " has wrong linear1 count!" + << " stored: " << graph_.VarToNumLinear1(v) + << " actual: " << linear1_count[v]; + return false; + } + if (graph_.VarToConstraints(v).contains(kObjectiveConstraint)) { + ++num_in_objective; + if (!objective_map_.contains(v)) { + LOG(INFO) << "Variable " << v + << " is marked as part of the objective but isn't."; + return false; + } + } + } + if (num_in_objective != objective_map_.size()) { + LOG(INFO) << "Not all variables are marked as part of the objective"; + return false; + } + + return true; +} + void PresolveContext::ClearStats() { stats_by_rule_name_.clear(); } int PresolveContext::NewIntVar(const Domain& domain) { - IntegerVariableProto* const var = working_model->add_variables(); + IntegerVariableProto* const var = working_model_->add_variables(); FillDomainInProto(domain, var); InitializeNewDomains(); - return working_model->variables_size() - 1; + return working_model_->variables_size() - 1; } int PresolveContext::NewIntVarWithDefinition( @@ -93,7 +259,7 @@ int PresolveContext::NewIntVarWithDefinition( auto* new_linear = append_constraint_to_mapping_model ? NewMappingConstraint(__FILE__, __LINE__)->mutable_linear() - : working_model->add_constraints()->mutable_linear(); + : working_model_->add_constraints()->mutable_linear(); for (const auto [var, coeff] : definition) { new_linear->add_vars(var); new_linear->add_coeffs(coeff); @@ -102,19 +268,16 @@ int PresolveContext::NewIntVarWithDefinition( new_linear->add_coeffs(-1); new_linear->add_domain(0); new_linear->add_domain(0); - if (PossibleIntegerOverflow(*working_model, new_linear->vars(), + if (PossibleIntegerOverflow(*working_model_, new_linear->vars(), new_linear->coeffs())) { UpdateRuleStats("TODO new_var_definition : possible overflow."); if (append_constraint_to_mapping_model) { mapping_model->mutable_constraints()->RemoveLast(); } else { - working_model->mutable_constraints()->RemoveLast(); + working_model_->mutable_constraints()->RemoveLast(); } return -1; } - if (!append_constraint_to_mapping_model) { - UpdateNewConstraintsVariableUsage(); - } solution_crush_.SetVarToLinearExpression(new_var, definition); return new_var; @@ -151,29 +314,33 @@ int PresolveContext::GetFalseLiteral() { return NegatedRef(GetTrueLiteral()); } ConstraintProto* PresolveContext::AddEnforcedConstraint( absl::Span enforcement_literals) { - ConstraintProto* const new_ct = working_model->add_constraints(); + ConstraintProto* const new_ct = working_model_->add_constraints(); *new_ct->mutable_enforcement_literal() = {enforcement_literals.begin(), enforcement_literals.end()}; return new_ct; } -ConstraintProto* PresolveContext::AddEnforcedConstraint(ConstraintProto* ct) { - ConstraintProto* const new_ct = working_model->add_constraints(); - *new_ct->mutable_enforcement_literal() = ct->enforcement_literal(); - return new_ct; +ConstraintProto* PresolveContext::AddEnforcedConstraint( + const ConstraintProto& ct) { + return AddEnforcedConstraint(ct.enforcement_literal()); +} + +ConstraintProto* PresolveContext::AddEnforcedConstraint( + const ConstraintProto* ct) { + return AddEnforcedConstraint(ct->enforcement_literal()); } // a => b. void PresolveContext::AddImplication(int a, int b) { if (a == b) return; - ConstraintProto* const ct = working_model->add_constraints(); + ConstraintProto* const ct = working_model_->add_constraints(); ct->add_enforcement_literal(a); ct->mutable_bool_and()->add_literals(b); } // b => x in [lb, ub]. void PresolveContext::AddImplyInDomain(int b, int x, const Domain& domain) { - ConstraintProto* const imply = working_model->add_constraints(); + ConstraintProto* const imply = working_model_->add_constraints(); // Doing it like this seems to use slightly less memory. // TODO(user): Find the best way to create such small proto. @@ -186,7 +353,7 @@ void PresolveContext::AddImplyInDomain(int b, int x, const Domain& domain) { void PresolveContext::AddImplyInDomain(int b, const LinearExpressionProto& expr, const Domain& domain) { - ConstraintProto* const imply = working_model->add_constraints(); + ConstraintProto* const imply = working_model_->add_constraints(); imply->mutable_enforcement_literal()->Resize(1, b); LinearConstraintProto* mutable_linear = imply->mutable_linear(); @@ -288,18 +455,6 @@ int64_t PresolveContext::FixedValue(const LinearExpressionProto& expr) const { return result; } -std::optional PresolveContext::FixedValueOrNullopt( - const LinearExpressionProto& expr) const { - int64_t result = expr.offset(); - for (int i = 0; i < expr.vars_size(); ++i) { - if (expr.coeffs(i) == 0) continue; - const Domain& domain = domains_[expr.vars(i)]; - if (!domain.IsFixed()) return std::nullopt; - result += expr.coeffs(i) * domain.Min(); - } - return result; -} - Domain PresolveContext::DomainSuperSetOf( const LinearExpressionProto& expr) const { Domain result(expr.offset()); @@ -351,7 +506,7 @@ bool PresolveContext::ExpressionIsALiteral(const LinearExpressionProto& expr, // Note that we only support converted intervals. bool PresolveContext::IntervalIsConstant(int ct_ref) const { - const ConstraintProto& proto = working_model->constraints(ct_ref); + const ConstraintProto& proto = working_model_->constraints(ct_ref); if (!proto.enforcement_literal().empty()) return false; if (!IsFixed(proto.interval().start())) return false; if (!IsFixed(proto.interval().size())) return false; @@ -365,7 +520,7 @@ std::string PresolveContext::IntervalDebugString(int ct_ref) const { EndMax(ct_ref), ")"); } else if (ConstraintIsOptional(ct_ref)) { const int literal = - working_model->constraints(ct_ref).enforcement_literal(0); + working_model_->constraints(ct_ref).enforcement_literal(0); if (SizeMin(ct_ref) == SizeMax(ct_ref)) { return absl::StrCat("interval_", ct_ref, "(lit=", literal, ", ", StartMin(ct_ref), " --(", SizeMin(ct_ref), ")--> ", @@ -387,37 +542,37 @@ std::string PresolveContext::IntervalDebugString(int ct_ref) const { int64_t PresolveContext::StartMin(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MinOf(interval.start()); } int64_t PresolveContext::StartMax(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MaxOf(interval.start()); } int64_t PresolveContext::EndMin(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MinOf(interval.end()); } int64_t PresolveContext::EndMax(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MaxOf(interval.end()); } int64_t PresolveContext::SizeMin(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MinOf(interval.size()); } int64_t PresolveContext::SizeMax(int ct_ref) const { const IntervalConstraintProto& interval = - working_model->constraints(ct_ref).interval(); + working_model_->constraints(ct_ref).interval(); return MaxOf(interval.size()); } @@ -426,9 +581,8 @@ int64_t PresolveContext::SizeMax(int ct_ref) const { // be the affine defining one. And in this case the code using this function // should do the proper stuff. bool PresolveContext::VariableIsUnique(int ref) const { - if (!ConstraintVariableGraphIsUpToDate()) return false; const int var = PositiveRef(ref); - return var_to_constraints_[var].size() == 1; + return graph_.VarToConstraints(var).size() == 1; } bool PresolveContext::VariableIsUniqueAndRemovable(int ref) const { @@ -437,10 +591,9 @@ bool PresolveContext::VariableIsUniqueAndRemovable(int ref) const { } bool PresolveContext::VariableWithCostIsUnique(int ref) const { - if (!ConstraintVariableGraphIsUpToDate()) return false; const int var = PositiveRef(ref); - return var_to_constraints_[var].size() == 2 && - var_to_constraints_[var].contains(kObjectiveConstraint); + return graph_.VarToConstraints(var).size() == 2 && + graph_.VarToConstraints(var).contains(kObjectiveConstraint); } // Tricky: Same remark as for VariableIsUniqueAndRemovable(). @@ -448,7 +601,6 @@ bool PresolveContext::VariableWithCostIsUnique(int ref) const { // Also if the objective domain is constraining, we can't have a preferred // direction, so we cannot easily remove such variable. bool PresolveContext::VariableWithCostIsUniqueAndRemovable(int ref) const { - if (!ConstraintVariableGraphIsUpToDate()) return false; const int var = PositiveRef(ref); return !params_.keep_all_feasible_solutions_in_presolve() && !objective_domain_is_constraining_ && VariableWithCostIsUnique(var); @@ -457,56 +609,29 @@ bool PresolveContext::VariableWithCostIsUniqueAndRemovable(int ref) const { // Here, even if the variable is equivalent to others, if its affine defining // constraints where removed, then it is not needed anymore. bool PresolveContext::VariableIsNotUsedAnymore(int ref) const { - if (!ConstraintVariableGraphIsUpToDate()) return false; - return var_to_constraints_[PositiveRef(ref)].empty(); + return graph_.VarToConstraints(PositiveRef(ref)).empty(); } void PresolveContext::MarkVariableAsRemoved(int var) { DCHECK(RefIsPositive(var)); - var_was_removed_[var] = true; -} - -// Note(user): I added an indirection and a function for this to be able to -// display debug information when this return true. This should actually rarely -// return true in the case where it is used. -bool PresolveContext::VariableWasRemoved(int ref) const { - // TODO(user): clean this up, but we have a lot of usage with literal... - const int var = PositiveRef(ref); - - // It is okay to reuse removed fixed variable. - if (IsFixed(var)) return false; - if (!var_was_removed_[var]) return false; - - if (DEBUG_MODE && !var_to_constraints_[var].empty()) { - SOLVER_LOG(logger_, "Variable ", var, - " was removed, yet it appears in some constraints!"); - SOLVER_LOG(logger_, "affine relation: ", AffineRelationDebugString(var)); - for (const int c : var_to_constraints_[var]) { - SOLVER_LOG(logger_, "constraint #", c, " : ", - c >= 0 - ? ProtobufShortDebugString(working_model->constraints(c)) - : ""); - } - } - return true; + graph_.MarkVariableAsRemoved(var); } bool PresolveContext::VariableIsOnlyUsedInEncodingAndMaybeInObjective( int var) const { CHECK(RefIsPositive(var)); - if (!ConstraintVariableGraphIsUpToDate()) return false; - if (var_to_num_linear1_[var] == 0) return false; - return var_to_num_linear1_[var] == var_to_constraints_[var].size() || - (var_to_constraints_[var].contains(kObjectiveConstraint) && - var_to_num_linear1_[var] + 1 == var_to_constraints_[var].size()); + if (graph_.VarToNumLinear1(var) == 0) return false; + return graph_.VarToNumLinear1(var) == graph_.VarToConstraints(var).size() || + (graph_.VarToConstraints(var).contains(kObjectiveConstraint) && + graph_.VarToNumLinear1(var) + 1 == + graph_.VarToConstraints(var).size()); } bool PresolveContext::VariableIsOnlyUsedInLinear1AndOneExtraConstraint( int var) const { - if (!ConstraintVariableGraphIsUpToDate()) return false; - if (var_to_num_linear1_[var] == 0) return false; + if (graph_.VarToNumLinear1(var) == 0) return false; CHECK(RefIsPositive(var)); - return var_to_num_linear1_[var] + 1 == var_to_constraints_[var].size(); + return graph_.VarToNumLinear1(var) + 1 == graph_.VarToConstraints(var).size(); } const Domain& PresolveContext::DomainOf(int var) const { @@ -627,7 +752,7 @@ ABSL_MUST_USE_RESULT bool PresolveContext::SetLiteralToTrue(int lit) { } bool PresolveContext::ConstraintIsInactive(int index) const { - const ConstraintProto& ct = working_model->constraints(index); + const ConstraintProto& ct = working_model_->constraints(index); if (ct.constraint_case() == ConstraintProto::ConstraintCase::CONSTRAINT_NOT_SET) { return true; @@ -655,7 +780,7 @@ bool PresolveContext::MarkConstraintAsFalse(ConstraintProto* ct, } bool PresolveContext::ConstraintIsOptional(int ct_ref) const { - const ConstraintProto& ct = working_model->constraints(ct_ref); + const ConstraintProto& ct = working_model_->constraints(ct_ref); bool contains_one_free_literal = false; for (const int literal : ct.enforcement_literal()) { if (LiteralIsFalse(literal)) return false; @@ -686,139 +811,6 @@ void PresolveContext::UpdateRuleStats(std::string_view name, int num_times) { } } -void PresolveContext::UpdateLinear1Usage(const ConstraintProto& ct, int c) { - const int old_var = constraint_to_linear1_var_[c]; - if (old_var >= 0) { - var_to_num_linear1_[old_var]--; - DCHECK_GE(var_to_num_linear1_[old_var], 0); - } - if (ct.constraint_case() == ConstraintProto::ConstraintCase::kLinear && - ct.linear().vars().size() == 1) { - const int var = PositiveRef(ct.linear().vars(0)); - constraint_to_linear1_var_[c] = var; - var_to_num_linear1_[var]++; - } else { - constraint_to_linear1_var_[c] = -1; - } -} - -void PresolveContext::MaybeResizeIntervalData() { - // Lazy allocation so that we only do that if there are some interval. - const int num_constraints = constraint_to_vars_.size(); - if (constraint_to_intervals_.size() != num_constraints) { - constraint_to_intervals_.resize(num_constraints); - interval_usage_.resize(num_constraints); - } -} - -void PresolveContext::AddVariableUsage(int c) { - const ConstraintProto& ct = working_model->constraints(c); - - constraint_to_vars_[c] = UsedVariables(ct); - for (const int v : constraint_to_vars_[c]) { - DCHECK_LT(v, var_to_constraints_.size()); - DCHECK(!VariableWasRemoved(v)); - var_to_constraints_[v].insert(c); - } - - std::vector used_interval = UsedIntervals(ct); - if (!used_interval.empty()) { - MaybeResizeIntervalData(); - constraint_to_intervals_[c].swap(used_interval); - for (const int i : constraint_to_intervals_[c]) interval_usage_[i]++; - } - - UpdateLinear1Usage(ct, c); - -#ifdef CHECK_HINT - // Crash if the loaded hint is infeasible for this constraint. - // This is helpful to debug a wrong presolve that kill a feasible solution. - if (working_model->has_solution_hint() && - solution_crush_.SolutionIsLoaded() && - !ConstraintIsFeasible(*working_model, ct, - solution_crush_.GetVarValues())) { - LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " - << ct.ShortDebugString(); - } -#endif -} - -void PresolveContext::EraseFromVarToConstraint(int var, int c) { - var_to_constraints_[var].erase(c); - if (var_to_constraints_[var].size() <= 3) { - var_with_reduced_small_degree.Set(var); - } -} - -void PresolveContext::UpdateConstraintVariableUsage(int c) { - if (is_unsat_) return; - DCHECK_EQ(constraint_to_vars_.size(), working_model->constraints_size()); - const ConstraintProto& ct = working_model->constraints(c); - - // We don't optimize the interval usage as this is not super frequent. - std::vector used_interval = UsedIntervals(ct); - if (c < constraint_to_intervals_.size() || !used_interval.empty()) { - MaybeResizeIntervalData(); - for (const int i : constraint_to_intervals_[c]) interval_usage_[i]--; - constraint_to_intervals_[c].swap(used_interval); - for (const int i : constraint_to_intervals_[c]) interval_usage_[i]++; - } - - // For the variables, we avoid an erase() followed by an insert() for the - // variables that didn't change. - std::vector new_usage = UsedVariables(ct); - const absl::Span old_usage = constraint_to_vars_[c]; - const int old_size = old_usage.size(); - int i = 0; - for (const int var : new_usage) { - DCHECK(!VariableWasRemoved(var)); - while (i < old_size && old_usage[i] < var) { - EraseFromVarToConstraint(old_usage[i], c); - ++i; - } - if (i < old_size && old_usage[i] == var) { - ++i; - } else { - var_to_constraints_[var].insert(c); - } - } - for (; i < old_size; ++i) { - EraseFromVarToConstraint(old_usage[i], c); - } - constraint_to_vars_[c].swap(new_usage); - - UpdateLinear1Usage(ct, c); - -#ifdef CHECK_HINT - // Crash if the loaded hint is infeasible for this constraint. - // This is helpful to debug a wrong presolve that kill a feasible solution. - if (working_model->has_solution_hint() && - solution_crush_.SolutionIsLoaded() && - !ConstraintIsFeasible(*working_model, ct, - solution_crush_.GetVarValues())) { - LOG(FATAL) << "Hint infeasible for constraint #" << c << " : " - << ct.ShortDebugString(); - } -#endif -} - -bool PresolveContext::ConstraintVariableGraphIsUpToDate() const { - if (is_unsat_) return true; // We do not care in this case. - return constraint_to_vars_.size() == working_model->constraints_size(); -} - -void PresolveContext::UpdateNewConstraintsVariableUsage() { - if (is_unsat_) return; - const int old_size = constraint_to_vars_.size(); - const int new_size = working_model->constraints_size(); - DCHECK_LE(old_size, new_size); - constraint_to_vars_.resize(new_size); - constraint_to_linear1_var_.resize(new_size, -1); - for (int c = old_size; c < new_size; ++c) { - AddVariableUsage(c); - } -} - bool PresolveContext::HasUnusedAffineVariable() const { if (is_unsat_) return false; // We do not care in this case. if (params_.keep_all_feasible_solutions_in_presolve()) return false; @@ -826,7 +818,7 @@ bool PresolveContext::HasUnusedAffineVariable() const { // We can leave non-optimal stuff around if we reach the time limit. if (time_limit_->LimitReached()) return false; - for (int var = 0; var < working_model->variables_size(); ++var) { + for (int var = 0; var < working_model_->variables_size(); ++var) { if (VariableIsNotUsedAnymore(var)) continue; if (IsFixed(var)) continue; const auto& constraints = VarToConstraints(var); @@ -839,65 +831,6 @@ bool PresolveContext::HasUnusedAffineVariable() const { return false; } -// TODO(user): Also test var_to_constraints_ !! -bool PresolveContext::ConstraintVariableUsageIsConsistent() { - // We do not care in these cases. - if (is_unsat_) return true; - if (time_limit_->LimitReached()) return true; - - if (var_to_constraints_.size() != working_model->variables_size()) { - LOG(INFO) << "Wrong var_to_constraints_ size!"; - return false; - } - if (constraint_to_vars_.size() != working_model->constraints_size()) { - LOG(INFO) << "Wrong constraint_to_vars size!"; - return false; - } - std::vector linear1_count(var_to_constraints_.size(), 0); - for (int c = 0; c < constraint_to_vars_.size(); ++c) { - const ConstraintProto& ct = working_model->constraints(c); - if (constraint_to_vars_[c] != UsedVariables(ct)) { - LOG(INFO) << "Wrong variables usage for constraint: \n" - << ProtobufDebugString(ct) - << " old_size: " << constraint_to_vars_[c].size(); - return false; - } - if (ct.constraint_case() == ConstraintProto::kLinear && - ct.linear().vars().size() == 1) { - linear1_count[PositiveRef(ct.linear().vars(0))]++; - if (constraint_to_linear1_var_[c] != PositiveRef(ct.linear().vars(0))) { - LOG(INFO) << "Wrong variables for linear1: \n" - << ProtobufDebugString(ct) - << " saved_var: " << constraint_to_linear1_var_[c]; - return false; - } - } - } - int num_in_objective = 0; - for (int v = 0; v < var_to_constraints_.size(); ++v) { - if (linear1_count[v] != var_to_num_linear1_[v]) { - LOG(INFO) << "Variable " << v << " has wrong linear1 count!" - << " stored: " << var_to_num_linear1_[v] - << " actual: " << linear1_count[v]; - return false; - } - if (var_to_constraints_[v].contains(kObjectiveConstraint)) { - ++num_in_objective; - if (!objective_map_.contains(v)) { - LOG(INFO) << "Variable " << v - << " is marked as part of the objective but isn't."; - return false; - } - } - } - if (num_in_objective != objective_map_.size()) { - LOG(INFO) << "Not all variables are marked as part of the objective"; - return false; - } - - return true; -} - // If a Boolean variable (one with domain [0, 1]) appear in this affine // equivalence class, then we want its representative to be Boolean. Note that // this is always possible because a Boolean variable can never be equal to a @@ -914,8 +847,8 @@ bool PresolveContext::AddRelation(int x, int y, int64_t c, int64_t o, // Boolean, it must be the representative. if (std::abs(c) != 1) return repo->TryAdd(x, y, c, o); - CHECK(!VariableWasRemoved(x)); - CHECK(!VariableWasRemoved(y)); + CHECK(!graph_.VariableWasRemoved(PositiveRef(x))); + CHECK(!graph_.VariableWasRemoved(PositiveRef(y))); // To avoid integer overflow, we always want to use the representative with // the smallest domain magnitude. Otherwise we might express a variable in say @@ -981,9 +914,7 @@ bool PresolveContext::PropagateAffineRelation(int var, int rep, int64_t coeff, } void PresolveContext::RemoveAllVariablesFromAffineRelationConstraint() { - for (auto& ref_map : var_to_constraints_) { - ref_map.erase(kAffineRelationConstraint); - } + graph_.RemoveAllVariablesFromAffineRelationConstraint(); } void PresolveContext::RemoveNonRepresentativeAffineVariableIfUnused(int var) { @@ -1016,22 +947,20 @@ void PresolveContext::RemoveVariableFromAffineRelation(int var) { CHECK(RefIsPositive(var)); CHECK_NE(var, rep); - CHECK_EQ(var_to_constraints_[var].size(), 1); - CHECK(var_to_constraints_[var].contains(kAffineRelationConstraint)); - CHECK(var_to_constraints_[rep].contains(kAffineRelationConstraint)); + CHECK_EQ(graph_.VarToConstraints(var).size(), 1); + CHECK(graph_.VarToConstraints(var).contains(kAffineRelationConstraint)); + CHECK(graph_.VarToConstraints(rep).contains(kAffineRelationConstraint)); - // We shouldn't reuse this variable again! - MarkVariableAsRemoved(var); - - // We do not call EraseFromVarToConstraint() on purpose here since the - // variable is removed. - var_to_constraints_[var].erase(kAffineRelationConstraint); + graph_.RemoveFromAffineRelation(var); affine_relations_.IgnoreFromClassSize(var); + // We shouldn't reuse this variable again! + graph_.MarkVariableAsRemoved(var); + // If the representative is left alone, we can remove it from the special // affine relation constraint too. if (affine_relations_.ClassSize(rep) == 1) { - EraseFromVarToConstraint(rep, kAffineRelationConstraint); + graph_.RemoveFromAffineRelation(rep); } if (VLOG_IS_ON(2)) { @@ -1047,7 +976,6 @@ void PresolveContext::CanonicalizeVariable(int ref) { const int new_var = NewIntVar(DomainOf(var).AdditionWith(Domain(-min))); CHECK(StoreAffineRelation(var, new_var, 1, min, /*debug_no_recursion=*/true)); UpdateRuleStats("variables: canonicalize domain"); - UpdateNewConstraintsVariableUsage(); } bool ScaleFloatingPointObjective(const SatParameters& params, @@ -1120,14 +1048,13 @@ bool PresolveContext::CanonicalizeAffineVariable(int ref, int64_t coeff, // at the end of the presolve. const int64_t min_value = new_domain.Min(); const int new_var = NewIntVar(new_domain.AdditionWith(Domain(-min_value))); - if (!working_model->variables(var).name().empty()) { - working_model->mutable_variables(new_var)->set_name( - working_model->variables(var).name()); + if (!working_model_->variables(var).name().empty()) { + working_model_->mutable_variables(new_var)->set_name( + working_model_->variables(var).name()); } CHECK(StoreAffineRelation(var, new_var, mod, offset + mod * min_value, /*debug_no_recursion=*/true)); UpdateRuleStats("variables: canonicalize affine domain"); - UpdateNewConstraintsVariableUsage(); return true; } @@ -1305,8 +1232,8 @@ bool PresolveContext::StoreAffineRelation(int var_x, int var_y, int64_t coeff, if (x != rep) modified_domains.Set(x); if (y != rep) modified_domains.Set(y); - var_to_constraints_[x].insert(kAffineRelationConstraint); - var_to_constraints_[y].insert(kAffineRelationConstraint); + graph_.AddToAffineRelation(x); + graph_.AddToAffineRelation(y); return true; } @@ -1314,8 +1241,8 @@ ABSL_MUST_USE_RESULT bool PresolveContext::StoreBooleanEqualityRelation( int ref_a, int ref_b) { if (is_unsat_) return false; - CHECK(!VariableWasRemoved(PositiveRef(ref_a))); - CHECK(!VariableWasRemoved(PositiveRef(ref_b))); + CHECK(!graph_.VariableWasRemoved(PositiveRef(ref_a))); + CHECK(!graph_.VariableWasRemoved(PositiveRef(ref_b))); CHECK(!DomainOf(PositiveRef(ref_a)).IsEmpty()); CHECK(!DomainOf(PositiveRef(ref_b)).IsEmpty()); CHECK(CanBeUsedAsLiteral(ref_a)); @@ -1395,24 +1322,23 @@ std::string PresolveContext::AffineRelationDebugString(int ref) const { void PresolveContext::ResetAfterCopy() { domains_.clear(); modified_domains.ResetAllToFalse(); - var_with_reduced_small_degree.ResetAllToFalse(); - var_to_constraints_.clear(); - var_to_num_linear1_.clear(); + + graph_.Reset(); + objective_map_.clear(); - DCHECK(!solution_crush_.SolutionIsLoaded()); + if (solution_crush_.SolutionIsLoaded()) { + solution_crush_.StoreSolutionAsHint(*working_model_); + } } -// Create the internal structure for any new variables in working_model. +// Create the internal structure for any new variables in working_model_. void PresolveContext::InitializeNewDomains() { - const int new_size = working_model->variables().size(); + const int new_size = working_model_->variables().size(); DCHECK_GE(new_size, domains_.size()); if (domains_.size() == new_size) return; modified_domains.Resize(new_size); - var_with_reduced_small_degree.Resize(new_size); - var_to_constraints_.resize(new_size); - var_to_num_linear1_.resize(new_size); - var_was_removed_.resize(new_size); + graph_.IncreaseNumVars(new_size); // We mark the domain as modified so we will look at these new variable during // our presolve loop. @@ -1420,7 +1346,7 @@ void PresolveContext::InitializeNewDomains() { domains_.resize(new_size); for (int i = old_size; i < new_size; ++i) { modified_domains.Set(i); - domains_[i] = ReadDomainFromProto(working_model->variables(i)); + domains_[i] = ReadDomainFromProto(working_model_->variables(i)); if (domains_[i].IsEmpty()) { is_unsat_ = true; return; @@ -1431,10 +1357,10 @@ void PresolveContext::InitializeNewDomains() { solution_crush_.Resize(new_size); } -void PresolveContext::LoadSolutionHint() { - const int num_vars = working_model->variables().size(); - if (working_model->has_solution_hint() || num_vars == 0) { - const auto hint_proto = working_model->solution_hint(); +void PresolveContext::LoadAndClampSolutionHint() { + const int num_vars = working_model_->variables().size(); + if (working_model_->has_solution_hint() || num_vars == 0) { + const auto hint_proto = working_model_->solution_hint(); absl::flat_hash_map hint_values; int num_changes = 0; for (int i = 0; i < hint_proto.vars().size(); ++i) { @@ -1459,6 +1385,28 @@ void PresolveContext::LoadSolutionHint() { } } +void PresolveContext::WriteHintToProto() { + if (ModelIsUnsat()) return; + + if (!solution_crush_.SolutionIsLoaded()) return; + const int num_vars = NumVariables(); + for (int i = 0; i < num_vars; ++i) { + // If the initial hint is incomplete or infeasible, the crushed hint might + // contain values outside of their respective domains (see SolutionCrush). + solution_crush_.SetOrUpdateVarToDomain(i, DomainOf(i)); + } + // If the time limit is reached, the presolved model might still contain + // non-representative "affine" variables. + for (int i = 0; i < num_vars; ++i) { + const auto relation = GetAffineRelation(i); + if (relation.representative != i) { + solution_crush_.SetVarToLinearExpression( + i, {{relation.representative, relation.coeff}}, relation.offset); + } + } + solution_crush_.StoreSolutionAsHint(*working_model_); +} + void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { CHECK(RefIsPositive(var)); CHECK_EQ(DomainOf(var).Size(), 2); @@ -1473,7 +1421,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { auto min_it = var_map.find(var_min); if (min_it != var_map.end()) { const int old_var = PositiveRef(min_it->second.Get(this)); - if (var_was_removed_[old_var]) { + if (graph_.VariableWasRemoved(old_var)) { var_map.erase(min_it); min_it = var_map.end(); } @@ -1483,7 +1431,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { auto max_it = var_map.find(var_max); if (max_it != var_map.end()) { const int old_var = PositiveRef(max_it->second.Get(this)); - if (var_was_removed_[old_var]) { + if (graph_.VariableWasRemoved(old_var)) { var_map.erase(max_it); max_it = var_map.end(); } @@ -1551,8 +1499,8 @@ bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var, int64_t value, bool add_constraints) { DCHECK(RefIsPositive(var)); - DCHECK(!VariableWasRemoved(PositiveRef(literal))); - DCHECK(!VariableWasRemoved(var)); + DCHECK(!graph_.VariableWasRemoved(PositiveRef(literal))); + DCHECK(!graph_.VariableWasRemoved(var)); if (is_unsat_) return false; absl::flat_hash_map& var_map = encoding_[var]; @@ -1580,7 +1528,7 @@ bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var, // radiation_m18_12_05_sat.fzn. The value was encoded, but maybe we never // used the involved variables / constraints, so it was removed (with the // encoding constraints) from the model already! We have to be careful. - if (VariableWasRemoved(PositiveRef(previous_literal))) { + if (graph_.VariableWasRemoved(PositiveRef(previous_literal))) { it->second = SavedLiteral(literal); } else { if (literal != previous_literal) { @@ -1751,7 +1699,7 @@ bool PresolveContext::StoreLiteralImpliesVarNeValue(int literal, int var, bool PresolveContext::HasVarValueEncoding(int ref, int64_t value, int* literal) { - CHECK(!VariableWasRemoved(PositiveRef(ref))); + CHECK(!graph_.VariableWasRemoved(PositiveRef(ref))); // TODO(user): do instead a DCHECK(VariableIsAffineRepresentative(ref)) if (!CanonicalizeEncoding(&ref, &value)) return false; DCHECK(RefIsPositive(ref)); @@ -1769,7 +1717,7 @@ bool PresolveContext::HasVarValueEncoding(int ref, int64_t value, const auto it = first_it->second.find(value); if (it == first_it->second.end()) return false; - if (VariableWasRemoved(it->second.Get(this))) return false; + if (graph_.VariableWasRemoved(it->second.Get(this))) return false; if (literal != nullptr) { *literal = it->second.Get(this); } @@ -1813,7 +1761,7 @@ int64_t PresolveContext::GetValueEncodingSize(int ref) const { } int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { - CHECK(!VariableWasRemoved(ref)); + CHECK(!graph_.VariableWasRemoved(ref)); if (!CanonicalizeEncoding(&ref, &value)) return GetFalseLiteral(); // Positive after CanonicalizeEncoding(). @@ -1834,7 +1782,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { auto it = var_map.find(value); if (it != var_map.end()) { const int lit = it->second.Get(this); - if (VariableWasRemoved(PositiveRef(lit))) { + if (graph_.VariableWasRemoved(PositiveRef(lit))) { // If the variable was already removed, for now we create a new one. // This should be rare hopefully. var_map.erase(value); @@ -1859,7 +1807,7 @@ int PresolveContext::GetOrCreateVarValueEncoding(int ref, int64_t value) { auto other_it = var_map.find(other_value); if (other_it != var_map.end()) { const int literal = NegatedRef(other_it->second.Get(this)); - if (VariableWasRemoved(PositiveRef(literal))) { + if (graph_.VariableWasRemoved(PositiveRef(literal))) { // If the variable was already removed, for now we create a new one. // This should be rare hopefully. var_map.erase(other_value); @@ -1913,7 +1861,7 @@ int PresolveContext::GetOrCreateAffineValueEncoding( } void PresolveContext::ReadObjectiveFromProto() { - const CpObjectiveProto& obj = working_model->objective(); + const CpObjectiveProto& obj = working_model_->objective(); // We do some small canonicalization here objective_proto_is_up_to_date_ = false; @@ -1973,7 +1921,7 @@ void PresolveContext::ReadObjectiveFromProto() { if (objective_map_[var] == 0) { RemoveVariableFromObjective(var); } else { - var_to_constraints_[var].insert(kObjectiveConstraint); + graph_.AddToObjective(var); } } @@ -1982,6 +1930,23 @@ void PresolveContext::ReadObjectiveFromProto() { } } +bool PresolveContext::RestrictObjectiveDomain(Domain domain) { + if (objective_domain_.IsIncludedIn(domain)) return true; + + UpdateRuleStats("objective: restricted objective domain"); + objective_domain_ = objective_domain_.IntersectionWith(domain); + + // This is defensive, we will recompute it on canonicalize. + objective_domain_is_constraining_ = true; + objective_proto_is_up_to_date_ = false; + if (objective_domain_.IsEmpty()) { + return NotifyThatModelIsUnsat( + "Constraint parallel to the objective makes the objective domain " + "empty"); + } + return true; +} + bool PresolveContext::CanonicalizeOneObjectiveVariable(int var) { const auto it = objective_map_.find(var); if (it == objective_map_.end()) return true; @@ -1993,9 +1958,8 @@ bool PresolveContext::CanonicalizeOneObjectiveVariable(int var) { if (params_.cp_model_presolve() && !params_.keep_all_feasible_solutions_in_presolve() && !objective_domain_is_constraining_ && - ConstraintVariableGraphIsUpToDate() && - var_to_constraints_[var].size() == 1 && - var_to_constraints_[var].contains(kObjectiveConstraint)) { + graph_.VarToConstraints(var).size() == 1 && + graph_.VarToConstraints(var).contains(kObjectiveConstraint)) { UpdateRuleStats("objective: variable not used elsewhere"); if (coeff > 0) { if (!IntersectDomainWith(var, Domain(MinOf(var)))) { @@ -2031,7 +1995,7 @@ bool PresolveContext::CanonicalizeOneObjectiveVariable(int var) { if (new_coeff == 0) { RemoveVariableFromObjective(r.representative); } else { - var_to_constraints_[r.representative].insert(kObjectiveConstraint); + graph_.AddToObjective(r.representative); if (IsFixed(r.representative)) { RemoveVariableFromObjective(r.representative); AddToObjectiveOffset(new_coeff * MinOf(r.representative)); @@ -2132,8 +2096,7 @@ bool PresolveContext::CanonicalizeObjective(bool simplify_domain) { // value without any issues. objective_domain_is_constraining_ = !implied_domain - .IntersectionWith(Domain(std::numeric_limits::min(), - objective_domain_.Max())) + .IntersectionWith(Domain(kint64min, objective_domain_.Max())) .IsIncludedIn(objective_domain_); if (objective_domain_is_constraining_) { VLOG(3) << "objective domain is constraining: size: " @@ -2166,7 +2129,7 @@ void PresolveContext::RemoveVariableFromObjective(int ref) { objective_proto_is_up_to_date_ = false; const int var = PositiveRef(ref); objective_map_.erase(var); - EraseFromVarToConstraint(var, kObjectiveConstraint); + graph_.RemoveFromObjective(var); } void PresolveContext::AddToObjective(int var, int64_t value) { @@ -2177,7 +2140,7 @@ void PresolveContext::AddToObjective(int var, int64_t value) { if (map_ref == 0) { RemoveVariableFromObjective(var); } else { - var_to_constraints_[var].insert(kObjectiveConstraint); + graph_.AddToObjective(var); } } @@ -2194,15 +2157,15 @@ void PresolveContext::AddLiteralToObjective(int ref, int64_t value) { if (map_ref == 0) { RemoveVariableFromObjective(var); } else { - var_to_constraints_[var].insert(kObjectiveConstraint); + graph_.AddToObjective(var); } } bool PresolveContext::AddToObjectiveOffset(int64_t delta) { objective_proto_is_up_to_date_ = false; const int64_t temp = CapAdd(objective_integer_before_offset_, delta); - if (temp == std::numeric_limits::min()) return false; - if (temp == std::numeric_limits::max()) return false; + if (temp == kint64min) return false; + if (temp == kint64max) return false; objective_integer_before_offset_ = temp; // Tricky: The objective domain is without the offset, so we need to shift it. @@ -2242,7 +2205,7 @@ bool PresolveContext::SubstituteVariableInObjective( std::abs(coeff_in_equality) * std::max(std::abs(MinOf(var_in_equality)), std::abs(MaxOf(var_in_equality)))); - if (new_value == std::numeric_limits::max()) return false; + if (new_value == kint64max) return false; objective_overflow_detection_ = new_value; // Compute the objective offset change. @@ -2272,7 +2235,7 @@ bool PresolveContext::SubstituteVariableInObjective( if (map_ref == 0) { RemoveVariableFromObjective(var); } else { - var_to_constraints_[var].insert(kObjectiveConstraint); + graph_.AddToObjective(var); } } @@ -2312,8 +2275,7 @@ bool PresolveContext::SubstituteVariableInObjective( // objective value without any issues. objective_domain_is_constraining_ = !implied_domain - .IntersectionWith(Domain(std::numeric_limits::min(), - objective_domain_.Max())) + .IntersectionWith(Domain(kint64min, objective_domain_.Max())) .IsIncludedIn(objective_domain_); if (objective_domain_is_constraining_) { VLOG(3) << "objective domain is constraining: size: " @@ -2333,7 +2295,7 @@ bool PresolveContext::ExploitExactlyOneInObjective( if (objective_map_.empty()) return false; if (exactly_one.empty()) return false; - int64_t min_coeff = std::numeric_limits::max(); + int64_t min_coeff = kint64max; for (const int ref : exactly_one) { const auto it = objective_map_.find(PositiveRef(ref)); if (it == objective_map_.end()) return false; @@ -2385,7 +2347,7 @@ bool PresolveContext::ShiftCostInExactlyOne(absl::Span exactly_one, // The value will be zero if it wasn't present. int64_t& map_ref = objective_map_[var]; if (map_ref == 0) { - var_to_constraints_[var].insert(kObjectiveConstraint); + graph_.AddToObjective(var); } if (RefIsPositive(ref)) { map_ref -= shift; @@ -2415,8 +2377,7 @@ bool PresolveContext::ShiftCostInExactlyOne(absl::Span exactly_one, // // TODO(user): This is a bit hacky, find a nicer way. if (!objective_domain_is_constraining_) { - objective_domain_ = - Domain(std::numeric_limits::min(), objective_domain_.Max()); + objective_domain_ = Domain(kint64min, objective_domain_.Max()); } return true; @@ -2437,7 +2398,7 @@ void PresolveContext::WriteObjectiveToProto() const { [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); - CpObjectiveProto* mutable_obj = working_model->mutable_objective(); + CpObjectiveProto* mutable_obj = working_model_->mutable_objective(); mutable_obj->set_offset(objective_offset_); mutable_obj->set_scaling_factor(objective_scaling_factor_); mutable_obj->set_integer_before_offset(objective_integer_before_offset_); @@ -2457,8 +2418,8 @@ void PresolveContext::WriteObjectiveToProto() const { } void PresolveContext::WriteVariableDomainsToProto() const { - for (int i = 0; i < working_model->variables_size(); ++i) { - FillDomainInProto(DomainOf(i), working_model->mutable_variables(i)); + for (int i = 0; i < working_model_->variables_size(); ++i) { + FillDomainInProto(DomainOf(i), working_model_->mutable_variables(i)); } } @@ -2482,13 +2443,13 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( active_i, active_j); if (!IsFixed(time_i) && !IsFixed(time_j)) { - DCHECK(!PossibleIntegerOverflow(*working_model, + DCHECK(!PossibleIntegerOverflow(*working_model_, {time_i.vars(0), time_j.vars(0)}, {-time_i.coeffs(0), time_j.coeffs(0)})); } // result => (time_i <= time_j) && active_i && active_j. - ConstraintProto* const lesseq = working_model->add_constraints(); + ConstraintProto* const lesseq = working_model_->add_constraints(); lesseq->add_enforcement_literal(result); if (!IsFixed(time_i)) { lesseq->mutable_linear()->add_vars(time_i.vars(0)); @@ -2503,7 +2464,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( (IsFixed(time_i) ? FixedValue(time_i) : time_i.offset()) - (IsFixed(time_j) ? FixedValue(time_j) : time_j.offset()); lesseq->mutable_linear()->add_domain(offset); - lesseq->mutable_linear()->add_domain(std::numeric_limits::max()); + lesseq->mutable_linear()->add_domain(kint64max); CanonicalizeLinearConstraint(lesseq); if (!LiteralIsTrue(active_i)) { @@ -2515,7 +2476,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( // Not(result) && active_i && active_j => (time_i > time_j) { - ConstraintProto* const greater = working_model->add_constraints(); + ConstraintProto* const greater = working_model_->add_constraints(); if (!IsFixed(time_i)) { greater->mutable_linear()->add_vars(time_i.vars(0)); greater->mutable_linear()->add_coeffs(-time_i.coeffs(0)); @@ -2524,7 +2485,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( greater->mutable_linear()->add_vars(time_j.vars(0)); greater->mutable_linear()->add_coeffs(time_j.coeffs(0)); } - greater->mutable_linear()->add_domain(std::numeric_limits::min()); + greater->mutable_linear()->add_domain(kint64min); greater->mutable_linear()->add_domain(offset - 1); greater->add_enforcement_literal(NegatedRef(result)); @@ -2545,7 +2506,7 @@ int PresolveContext::GetOrCreateReifiedPrecedenceLiteral( const auto& rev_it = reified_precedences_cache_.find( GetReifiedPrecedenceKey(time_j, time_i, active_j, active_i)); if (rev_it != reified_precedences_cache_.end()) { - auto* const bool_or = working_model->add_constraints()->mutable_bool_or(); + auto* const bool_or = working_model_->add_constraints()->mutable_bool_or(); bool_or->add_literals(result); bool_or->add_literals(rev_it->second); if (!LiteralIsTrue(active_i)) { @@ -2563,11 +2524,9 @@ std::tuple PresolveContext::GetReifiedPrecedenceKey(const LinearExpressionProto& time_i, const LinearExpressionProto& time_j, int active_i, int active_j) { - const int var_i = - IsFixed(time_i) ? std::numeric_limits::min() : time_i.vars(0); + const int var_i = IsFixed(time_i) ? kint32min : time_i.vars(0); const int64_t coeff_i = IsFixed(time_i) ? 0 : time_i.coeffs(0); - const int var_j = - IsFixed(time_j) ? std::numeric_limits::min() : time_j.vars(0); + const int var_j = IsFixed(time_j) ? kint32min : time_j.vars(0); const int64_t coeff_j = IsFixed(time_j) ? 0 : time_j.coeffs(0); const int64_t offset = (IsFixed(time_i) ? FixedValue(time_i) : time_i.offset()) - @@ -2616,7 +2575,7 @@ bool LoadModelForProbing(PresolveContext* context, Model* local_model) { // Update the domain in the current CpModelProto. context->WriteVariableDomainsToProto(); - const CpModelProto& model_proto = *(context->working_model); + const CpModelProto& model_proto = context->WorkingModel(); // Adapt some of the parameters during this probing phase. SatParameters local_params = context->params(); local_params.set_use_implied_bounds(false); @@ -2630,7 +2589,9 @@ bool LoadModelForPresolve(const CpModelProto& model_proto, SatParameters params, *local_model->GetOrCreate() = std::move(params); local_model->GetOrCreate()->MergeWithGlobalTimeLimit( context->time_limit()); - local_model->Register(context->random()); + auto random = std::make_unique(context->random()); + local_model->Register(random.get()); + local_model->TakeOwnership(random.release()); auto* encoder = local_model->GetOrCreate(); encoder->DisableImplicationBetweenLiteral(); auto* mapping = local_model->GetOrCreate(); @@ -2852,7 +2813,7 @@ void CreateValidModelWithSingleConstraint(const ConstraintProto& ct, auto [it, inserted] = inverse_interval_map.insert({i, mini_model->constraints_size()}); if (inserted) { - const ConstraintProto& itv_ct = context->working_model->constraints(i); + const ConstraintProto& itv_ct = context->Constraint(i); *mini_model->add_constraints() = itv_ct; // Now add end = start + size for the interval. This is not strictly @@ -2928,8 +2889,8 @@ void CreateValidModelWithSingleConstraint(const ConstraintProto& ct, bool PresolveContext::DebugTestHintFeasibility() { WriteVariableDomainsToProto(); const absl::Span hint = solution_crush_.GetVarValues(); - if (hint.size() != working_model->variables().size()) return false; - return SolutionIsFeasible(*working_model, hint); + if (hint.size() != working_model_->variables().size()) return false; + return SolutionIsFeasible(*working_model_, hint); } } // namespace sat diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 41795b3389c..2e801ab49d7 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -94,20 +94,188 @@ class SavedVariable { ABSL_MUST_USE_RESULT bool ScaleFloatingPointObjective( const SatParameters& params, SolverLogger* logger, CpModelProto* proto); +// Class responsible for maintaining the constraint <-> variable graph. +class LazyConstraintVariableGraph { + public: + LazyConstraintVariableGraph(const CpModelProto* cp_model, + const SolutionCrush* solution_crush) + : cp_model_(*cp_model), solution_crush_(*solution_crush) {} + + // This must be called when new variable are created, BEFORE any other + // functions here. + void IncreaseNumVars(int num_vars) { + DCHECK_GE(num_vars, var_to_constraints_.size()); + var_with_reduced_small_degree_.Resize(num_vars); + var_to_constraints_.resize(num_vars); + var_to_num_linear1_.resize(num_vars); + var_was_removed_.resize(num_vars); + } + + // This must be called each time a constraint is maybe mutated, BEFORE any + // other functions here. Note that new constraints are handled automatically + // and do not need this. + void MarkForLazyUpdate(int c) { + if (c < to_update_.size()) { + to_update_.Set(c); + } + } + + // Variable <-> constraint graph. + // The vector list is sorted and contains unique elements. + // + // Important: To properly handle the objective, var_to_constraints[objective] + // contains kObjectiveConstraint (i.e. -1) so that if the objective appear in + // only one constraint, the constraint cannot be simplified. + absl::Span ConstraintToVars(int c) { + LazyUpdate(); + return constraint_to_vars_[c]; + } + const absl::flat_hash_set& VarToConstraints(int var) { + LazyUpdate(); + return var_to_constraints_[var]; + } + int IntervalUsage(int c) { + LazyUpdate(); + if (c >= interval_usage_.size()) return 0; + return interval_usage_[c]; + } + int VarToNumLinear1(int var) { + LazyUpdate(); + return var_to_num_linear1_[var]; + } + int ConstraintToLinear1Var(int c) { + LazyUpdate(); + return constraint_to_linear1_var_[c]; + } + + SparseBitset* MutableVarWithReducedSmallDegree() { + LazyUpdate(); + return &var_with_reduced_small_degree_; + } + + void RegisterVariablesUsedInAssumptions() { + for (const int ref : cp_model_.assumptions()) { + var_to_constraints_[PositiveRef(ref)].insert(kAssumptionsConstraint); + } + } + + // Special constraints are still handled by the PresolveContext directly. + void AddToAffineRelation(int var) { + var_to_constraints_[var].insert(kAffineRelationConstraint); + } + void RemoveFromAffineRelation(int var) { + EraseFromVarToConstraint(var, kAffineRelationConstraint); + } + void AddToObjective(int var) { + var_to_constraints_[var].insert(kObjectiveConstraint); + } + void RemoveFromObjective(int var) { + EraseFromVarToConstraint(var, kObjectiveConstraint); + } + + // TODO(user): This is a bit hacky, fix. + void RemoveAllVariablesFromAffineRelationConstraint() { + for (auto& ref_map : var_to_constraints_) { + ref_map.erase(kAffineRelationConstraint); + } + } + + void Reset() { + to_update_.ClearAndResize(0); + constraint_to_vars_.clear(); + constraint_to_intervals_.clear(); + constraint_to_linear1_var_.clear(); + var_with_reduced_small_degree_.ResetAllToFalse(); + var_to_constraints_.clear(); + var_to_num_linear1_.clear(); + var_was_removed_.clear(); + } + + // Note that once a variable is removed, we DCHECK() that no constraint + // introduce it again. This should only be called when + // var_to_constraints_[var] is empty after LazyUpdate(), however we don't want + // to call LazyUpdate() to check that. + void MarkVariableAsRemoved(int var) { + DCHECK(RefIsPositive(var)); + var_was_removed_[var] = true; + } + + bool VariableWasRemoved(int var) const { + return var_was_removed_[PositiveRef(var)]; + } + + private: + void MaybeResizeIntervalData(); + void EraseFromVarToConstraint(int var, int c); + void AddVariableUsage(int c); + void UpdateLinear1Usage(const ConstraintProto& ct, int c); + void UpdateConstraintVariableUsage(int c); + + // This lazyly recompute what needs to be recomputed. + void LazyUpdate() { + const int old_size = constraint_to_vars_.size(); + const int new_size = cp_model_.constraints_size(); + DCHECK_LE(old_size, new_size); + CHECK_EQ(old_size, to_update_.size()); + + if (new_size > old_size) { + constraint_to_vars_.resize(new_size); + constraint_to_linear1_var_.resize(new_size, -1); + for (int c = old_size; c < new_size; ++c) { + AddVariableUsage(c); + } + } + + // Note that the update assume new constraints have already been handled, + // so the order is important here. + for (const int c : to_update_.PositionsSetAtLeastOnce()) { + UpdateConstraintVariableUsage(c); + } + to_update_.ClearAndResize(new_size); + } + + // This can change behind the class back, but is never mutated here. + const CpModelProto& cp_model_; + const SolutionCrush& solution_crush_; + + // The constraints that needs to be updated by LazyUpdate(). + // Note that any new constraint will also need update. + SparseBitset to_update_; + + // Constraints <-> Variables graph. + std::vector> constraint_to_vars_; + std::vector> var_to_constraints_; + + // Number of constraints of the form [lit =>] var in domain. + std::vector constraint_to_linear1_var_; + std::vector var_to_num_linear1_; + + // We maintain how many time each interval is used. + std::vector> constraint_to_intervals_; + std::vector interval_usage_; + + // This is mainly used for debug, to not reuse some cached variable that + // are not used anymore in the model. + std::vector var_was_removed_; + + // Each time the constraint <-> variable graph is updated, we update this. + // A variable is added here iff its usage decreased and is now one or two. + SparseBitset var_with_reduced_small_degree_; +}; + // Wrap the CpModelProto we are presolving with extra data structure like the // in-memory domain of each variables and the constraint variable graph. class PresolveContext { public: PresolveContext(Model* model, CpModelProto* cp_model, CpModelProto* mapping) - : working_model(cp_model), - mapping_model(mapping), + : mapping_model(mapping), + lrat_proof_handler(model->Mutable()), logger_(model->GetOrCreate()), params_(*model->GetOrCreate()), time_limit_(model->GetOrCreate()), - random_(model->GetOrCreate()) { - lrat_proof_handler = LratProofHandler::MaybeCreate( - model, /*enable_rat_proofs=*/params_.cp_model_pure_sat_presolve()); - } + random_(model->GetOrCreate()), + working_model_(cp_model), + graph_(cp_model, &solution_crush_) {} // Helpers to adds new variables to the presolved model. @@ -119,6 +287,12 @@ class PresolveContext { // WARNING: this does not set any hint value for the new variable. int NewBoolVar(absl::string_view source); + // Changes the name of a variable. This is just for debug, and we don't need + // to notify anyone that a name changed. + void SetVarName(int var, absl::string_view name) { + working_model_->mutable_variables(var)->set_name(name); + } + // Creates a new integer variable with the given domain and definition. // By default this also creates the linking constraint new_var = definition. // Its hint value is set to the value of the definition. Returns -1 if we @@ -141,11 +315,6 @@ class PresolveContext { int GetTrueLiteral(); int GetFalseLiteral(); - // Shortcuts to create enforced constraints. - ConstraintProto* AddEnforcedConstraint( - absl::Span enforcement_literals); - ConstraintProto* AddEnforcedConstraint(ConstraintProto* ct); - // a => b. void AddImplication(int a, int b); @@ -193,10 +362,6 @@ class PresolveContext { bool IsFixed(const LinearExpressionProto& expr) const; int64_t FixedValue(const LinearExpressionProto& expr) const; - // This is faster than testing IsFixed() + FixedValue(). - std::optional FixedValueOrNullopt( - const LinearExpressionProto& expr) const; - // Accepts any proto with two parallel vector .vars() and .coeffs(), like // LinearConstraintProto or ObjectiveProto or LinearExpressionProto but beware // that this ignore any offset. @@ -275,7 +440,9 @@ class PresolveContext { // Functions to make sure that once we remove a variable, we no longer reuse // it. void MarkVariableAsRemoved(int var); - bool VariableWasRemoved(int ref) const; + bool VariableWasRemoved(int var) const { + return graph_.VariableWasRemoved(var); + } // Same as VariableIsUniqueAndRemovable() except that in this case the // variable also appear in the objective in addition to a single constraint. @@ -333,20 +500,25 @@ class PresolveContext { // Updates the constraints <-> variables graph. This needs to be called each // time a constraint is modified. - void UpdateConstraintVariableUsage(int c); - - // At the beginning of the presolve, we delay the costly creation of this - // "graph" until we at least ran some basic presolve. This is because during - // a LNS neighborhood, many constraints will be reduced significantly by - // this "simple" presolve. - bool ConstraintVariableGraphIsUpToDate() const; + // + // TODO(user): Remove and call MarkForLazyUpdate() from MutableConstraint() + // instead. Note that this will require some refactoring to make sure we only + // call MutableConstraint() when we are about to change the constraint though, + // and not all the time. It is also still tricky because if one call + // MutableConstraint() just once, and use one of the function from this class, + // it would need to be called again later... + void UpdateConstraintVariableUsage(int c) { + if (is_unsat_) return; + graph_.MarkForLazyUpdate(c); + } - // Calls UpdateConstraintVariableUsage() on all newly created constraints. - void UpdateNewConstraintsVariableUsage(); + SparseBitset* MutableVarWithReducedSmallDegree() { + return graph_.MutableVarWithReducedSmallDegree(); + } // Returns true if our current constraints <-> variables graph is ok. // This is meant to be used in DEBUG mode only. - bool ConstraintVariableUsageIsConsistent(); + bool ConstraintVariableUsageIsConsistent() const; // Loop over all variable and return true if one of them is only used in // affine relation and is not a representative. This is in O(num_vars) and @@ -418,7 +590,7 @@ class PresolveContext { bool PropagateAffineRelation(int var); bool PropagateAffineRelation(int var, int rep, int64_t coeff, int64_t offset); - // Creates the internal structure for any new variables in working_model. + // Creates the internal structure for any new variables in working_model_. void InitializeNewDomains(); // This is a bit hacky. Clear some fields. See call site. @@ -438,10 +610,6 @@ class PresolveContext { // If an encoding already exists, it adds the two implications between // the previous encoding and the new encoding. // - // Important: This does not update the constraint<->variable graph, so - // ConstraintVariableGraphIsUpToDate() will be false until - // UpdateNewConstraintsVariableUsage() is called. - // // Returns false if the model become UNSAT. // // TODO(user): This function is not always correct if @@ -451,18 +619,10 @@ class PresolveContext { // Gets the associated literal if it is already created. Otherwise // create it, add the corresponding constraints and returns it. - // - // Important: This does not update the constraint<->variable graph, so - // ConstraintVariableGraphIsUpToDate() will be false until - // UpdateNewConstraintsVariableUsage() is called. int GetOrCreateVarValueEncoding(int ref, int64_t value); // Gets the associated literal if it is already created. Otherwise // create it, add the corresponding constraints and returns it. - // - // Important: This does not update the constraint<->variable graph, so - // ConstraintVariableGraphIsUpToDate() will be false until - // UpdateNewConstraintsVariableUsage() is called. int GetOrCreateAffineValueEncoding(const LinearExpressionProto& expr, int64_t value); @@ -529,6 +689,7 @@ class PresolveContext { // anything with that variable since it appear in at least two constraints. void ReadObjectiveFromProto(); bool AddToObjectiveOffset(int64_t delta); + ABSL_MUST_USE_RESULT bool RestrictObjectiveDomain(Domain domain); ABSL_MUST_USE_RESULT bool CanonicalizeOneObjectiveVariable(int var); ABSL_MUST_USE_RESULT bool CanonicalizeObjective(bool simplify_domain = true); void WriteObjectiveToProto() const; @@ -608,18 +769,12 @@ class PresolveContext { // contains kObjectiveConstraint (i.e. -1) so that if the objective appear in // only one constraint, the constraint cannot be simplified. absl::Span ConstraintToVars(int c) const { - DCHECK(ConstraintVariableGraphIsUpToDate()); - return constraint_to_vars_[c]; + return graph_.ConstraintToVars(c); } const absl::flat_hash_set& VarToConstraints(int var) const { - DCHECK(ConstraintVariableGraphIsUpToDate()); - return var_to_constraints_[var]; - } - int IntervalUsage(int c) const { - DCHECK(ConstraintVariableGraphIsUpToDate()); - if (c >= interval_usage_.size()) return 0; - return interval_usage_[c]; + return graph_.VarToConstraints(var); } + int IntervalUsage(int c) const { return graph_.IntervalUsage(c); } // Note this function does not update the constraint graph. It assumes this is // done elsewhere. @@ -636,9 +791,7 @@ class PresolveContext { // Make sure we never delete an "assumption" literal by using a special // constraint for that. void RegisterVariablesUsedInAssumptions() { - for (const int ref : working_model->assumptions()) { - var_to_constraints_[PositiveRef(ref)].insert(kAssumptionsConstraint); - } + graph_.RegisterVariablesUsedInAssumptions(); } // The "expansion" phase should be done once and allow to transform complex @@ -674,9 +827,13 @@ class PresolveContext { // Hint values outside the domain of their variable are adjusted to the // nearest value in this domain. Missing hint values are completed when // possible (e.g. for the model proto's fixed variables). - void LoadSolutionHint(); + // + // At the end of presolve, one should call WriteHintInProto() to update it. + void LoadAndClampSolutionHint(); + void WriteHintToProto(); SolutionCrush& solution_crush() { return solution_crush_; } + // This is slow O(problem_size) but can be used to debug presolve, either by // pinpointing the transition from feasible to infeasible or the other way // around if for some reason the presolve drop constraint that it shouldn't. @@ -685,14 +842,51 @@ class PresolveContext { SolverLogger* logger() const { return logger_; } const SatParameters& params() const { return params_; } TimeLimit* time_limit() { return time_limit_; } - ModelRandomGenerator* random() { return random_; } + absl::BitGenRef random() { return *random_; } + + // CpModelProto const accessors. + const CpModelProto& WorkingModel() const { return *working_model_; } + int NumConstraints() const { return working_model_->constraints().size(); } + int NumVariables() const { return working_model_->variables().size(); } + const ConstraintProto& Constraint(int c) const { + return working_model_->constraints(c); + } + + // Function to create a new constraint, with shortcuts for enforced ones. + ConstraintProto* AddConstraint() { return working_model_->add_constraints(); } + ConstraintProto* AddEnforcedConstraint( + absl::Span enforcement_literals); + ConstraintProto* AddEnforcedConstraint(const ConstraintProto& ct); + ConstraintProto* AddEnforcedConstraint(const ConstraintProto* ct); + + // CpModelProto mutable accessors. + ConstraintProto* MutableConstraint(int c) { + return working_model_->mutable_constraints(c); + } + void ClearConstraint(int c) { MutableConstraint(c)->Clear(); } + + // Sometimes we start creating a constraint but bail out, this is a "safe" + // pattern and shouldn't break invariants. + void RemoveLastConstraint() { + working_model_->mutable_constraints()->RemoveLast(); + } + + // WARNING. Only use when you know what you are doing as some modification + // might break the invariant maintained by this class. In particular, do not + // modify constraints via this pointer ! + // + // This is still exposed for efficiency and set-up in some places. The usage + // should stay minimal. + CpModelProto* UnsafeMutableWorkingModel() { return working_model_; } + SymmetryProto* MutableWorkingModelSymmetry() { + return working_model_->mutable_symmetry(); + } - CpModelProto* working_model = nullptr; CpModelProto* mapping_model = nullptr; // Used for the LRAT proof of inferred clauses during model copy and, if // applicable, during the pure SAT presolve. - std::unique_ptr lrat_proof_handler = nullptr; + LratProofHandler* lrat_proof_handler = nullptr; // Number of "rules" applied. This should be equal to the sum of all numbers // in stats_by_rule_name. This is used to decide if we should do one more pass @@ -710,10 +904,6 @@ class PresolveContext { // Each time a domain is modified this is set to true. SparseBitset modified_domains; - // Each time the constraint <-> variable graph is updated, we update this. - // A variable is added here iff its usage decreased and is now one or two. - SparseBitset var_with_reduced_small_degree; - // Advanced presolve. See this class comment. DomainDeductions deductions; @@ -727,16 +917,9 @@ class PresolveContext { absl::string_view file, int line); private: - void MaybeResizeIntervalData(); - - void EraseFromVarToConstraint(int var, int c); - // Helper to add an affine relation x = c.y + o to the given repository. bool AddRelation(int x, int y, int64_t c, int64_t o, AffineRelation* repo); - void AddVariableUsage(int c); - void UpdateLinear1Usage(const ConstraintProto& ct, int c); - // Makes sure we only insert encoding about the current representative. // // Returns false if ref cannot take the given value (it might not have been @@ -764,14 +947,22 @@ class PresolveContext { TimeLimit* time_limit_; ModelRandomGenerator* random_; + // The model we are modifying during presolve. + CpModelProto* working_model_ = nullptr; + // Initially false, and set to true on the first inconsistency. bool is_unsat_ = false; // The current domain of each variables. std::vector domains_; + // Used to maintain the hint during presolve. SolutionCrush solution_crush_; + // Store the variable <-> constraint graph on top of the working_model_. + // This is mutable for the lazy update. + mutable LazyConstraintVariableGraph graph_; + // Internal representation of the objective. During presolve, we first load // the objective in this format in order to have more efficient substitution // on large problems (also because the objective is often dense). At the end @@ -788,18 +979,6 @@ class PresolveContext { int64_t objective_integer_after_offset_; int64_t objective_integer_scaling_factor_; - // Constraints <-> Variables graph. - std::vector> constraint_to_vars_; - std::vector> var_to_constraints_; - - // Number of constraints of the form [lit =>] var in domain. - std::vector constraint_to_linear1_var_; - std::vector var_to_num_linear1_; - - // We maintain how many time each interval is used. - std::vector> constraint_to_intervals_; - std::vector interval_usage_; - // Used by GetTrueLiteral()/GetFalseLiteral(). bool true_literal_is_defined_ = false; int true_literal_; @@ -822,9 +1001,6 @@ class PresolveContext { // of the kAffineRelationConstraint. AffineRelation affine_relations_; - // Used by SetVariableAsRemoved() and VariableWasRemoved(). - std::vector var_was_removed_; - // Cache for the reified precedence literals created during the expansion of // the reservoir constraint. This cache is only valid during the expansion // phase, and is cleared afterwards. diff --git a/ortools/sat/presolve_context_test.cc b/ortools/sat/presolve_context_test.cc index 41163c3024d..578ba3ec31d 100644 --- a/ortools/sat/presolve_context_test.cc +++ b/ortools/sat/presolve_context_test.cc @@ -667,7 +667,6 @@ TEST(PresolveContextTest, VarIsOnlyUsedInEncoding) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); EXPECT_FALSE(context.VariableIsOnlyUsedInEncodingAndMaybeInObjective(0)); EXPECT_FALSE(context.VariableIsOnlyUsedInEncodingAndMaybeInObjective(1)); EXPECT_TRUE(context.VariableIsOnlyUsedInEncodingAndMaybeInObjective(2)); @@ -689,8 +688,7 @@ TEST(PresolveContextTest, ReifiedConstraintCache) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); LinearExpressionProto expr1; expr1.add_vars(2); expr1.add_coeffs(1); @@ -723,7 +721,6 @@ TEST(PresolveContextTest, ExploitFixedDomainOverflow) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); } TEST(PresolveContextTest, IntersectDomainWithConstant) { @@ -767,7 +764,7 @@ TEST(PresolveContextTest, IntersectDomainWithUpdatesHint) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); EXPECT_TRUE(context.IntersectDomainWith(0, Domain(5, 20))); @@ -815,7 +812,6 @@ TEST(PresolveContextTest, AddAffineRelation) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); EXPECT_TRUE(context.StoreAffineRelation(0, 1, 3, 0)); // x0 = 3x1 EXPECT_TRUE(context.StoreAffineRelation(2, 3, 5, 0)); // x2 = 5x3 @@ -850,7 +846,6 @@ TEST(PresolveContextTest, AddAffineRelationWithOffset) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); EXPECT_TRUE(context.StoreAffineRelation(0, 1, 3, 10)); // x0 = 3x1 + 10 EXPECT_TRUE(context.StoreAffineRelation(2, 3, 1, 30)); // x2 = x3 + 30 @@ -880,7 +875,6 @@ TEST(PresolveContextTest, AddAffineRelationPreventOverflow) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); // x0 = 10 x2 - 1e9. EXPECT_TRUE(context.StoreAffineRelation(0, 1, 10, -1000000000)); @@ -956,7 +950,6 @@ TEST(ExpressionIsALiteralTest, BasicApi) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.UpdateNewConstraintsVariableUsage(); int ref; const LinearExpressionProto expr1 = ParseTestProto(R"pb( @@ -1059,7 +1052,7 @@ TEST(PresolveContextTest, CanonicalizeLinearConstraint) { EXPECT_THAT(working_model.constraints(0), testing::EqualsProto(expected)); } -TEST(PresolveContextTest, LoadSolutionHint) { +TEST(PresolveContextTest, LoadAndClampSolutionHint) { Model model; CpModelProto working_model = ParseTestProto(R"pb( variables { domain: [ 0, 10 ] } @@ -1072,7 +1065,7 @@ TEST(PresolveContextTest, LoadSolutionHint) { )pb"); PresolveContext context(&model, &working_model, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.solution_crush().StoreSolutionAsHint(working_model); // All hints should be clamped to their respective domains, and new hints diff --git a/ortools/sat/presolve_encoding.cc b/ortools/sat/presolve_encoding.cc index 9e7abff5a29..d5e12cc5f26 100644 --- a/ortools/sat/presolve_encoding.cc +++ b/ortools/sat/presolve_encoding.cc @@ -32,6 +32,7 @@ #include "absl/strings/str_join.h" #include "google/protobuf/repeated_field.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/presolve_context.h" @@ -76,8 +77,8 @@ std::vector CreateVariableEncodingLocalModels( // Do a pass to gather all linear1 constraints. absl::flat_hash_map> var_to_linear1; - for (int i = 0; i < context->working_model->constraints_size(); ++i) { - const ConstraintProto& ct = context->working_model->constraints(i); + for (int i = 0; i < context->NumConstraints(); ++i) { + const ConstraintProto& ct = context->Constraint(i); if (ct.constraint_case() == ConstraintProto::kBoolOr || ct.constraint_case() == ConstraintProto::kAtMostOne || ct.constraint_case() == ConstraintProto::kExactlyOne) { @@ -106,12 +107,11 @@ std::vector CreateVariableEncodingLocalModels( // not related to the encodings. We use a bitset to keep track of all boolean // potentially encoding a domain for any variable and we filter out all // bool_or that are not linked to at least two of these booleans. - Bitset64 booleans_potentially_encoding_domain( - context->working_model->variables_size()); + Bitset64 booleans_potentially_encoding_domain(context->NumVariables()); for (const auto& [var, linear1_cts] : var_to_linear1) { for (const int c : linear1_cts) { - const ConstraintProto& ct = context->working_model->constraints(c); + const ConstraintProto& ct = context->Constraint(c); const int bool_var = PositiveRef(ct.enforcement_literal(0)); booleans_potentially_encoding_domain.Set(bool_var); bool_to_var_encodings[bool_var].push_back(var); @@ -124,7 +124,7 @@ std::vector CreateVariableEncodingLocalModels( int new_encoding_or_count = 0; for (int i = 0; i < constraint_encoding_or.size(); ++i) { const int c = constraint_encoding_or[i]; - const ConstraintProto& ct = context->working_model->constraints(c); + const ConstraintProto& ct = context->Constraint(c); const BoolArgumentProto& bool_ct = ct.constraint_case() == ConstraintProto::kAtMostOne ? ct.at_most_one() @@ -162,7 +162,7 @@ std::vector CreateVariableEncodingLocalModels( for (const int c : constraint_encoding_or) { var_to_bools.clear(); - const ConstraintProto& ct = context->working_model->constraints(c); + const ConstraintProto& ct = context->Constraint(c); const BoolArgumentProto& bool_ct = ct.constraint_case() == ConstraintProto::kAtMostOne ? ct.at_most_one() @@ -201,8 +201,8 @@ std::vector CreateVariableEncodingLocalModels( encoding_model.var_in_more_than_one_constraint_outside_the_local_model = (context->VarToConstraints(var).size() - linear1_cts.size() > 1); for (const int ct : linear1_cts) { - const int bool_var = PositiveRef( - context->working_model->constraints(ct).enforcement_literal(0)); + const int bool_var = + PositiveRef(context->Constraint(ct).enforcement_literal(0)); encoding_model.bools_only_used_inside_the_local_model.insert(bool_var); var_bool_counts[{var, bool_var}].linear1_count++; } @@ -233,7 +233,7 @@ bool BasicPresolveAndGetFullyEncodedDomains( // Fill ref_to_linear1 and do some basic presolving. Domain var_domain = context->DomainOf(local_model.var); for (const int ct : local_model.linear1_constraints) { - ConstraintProto* ct_proto = context->working_model->mutable_constraints(ct); + ConstraintProto* ct_proto = context->MutableConstraint(ct); DCHECK(ConstraintIsEncodingBound(*ct_proto)); const int ref = ct_proto->enforcement_literal(0); const Domain domain = ReadDomainFromProto(ct_proto->linear()); @@ -251,8 +251,7 @@ bool BasicPresolveAndGetFullyEncodedDomains( auto [it, inserted] = ref_to_linear1.insert({ref, ct}); if (!inserted) { *changed = true; - ConstraintProto* old_ct_proto = - context->working_model->mutable_constraints(it->second); + ConstraintProto* old_ct_proto = context->MutableConstraint(it->second); const Domain old_ct_domain = ReadDomainFromProto(old_ct_proto->linear()); const Domain new_domain = domain.IntersectionWith(old_ct_domain); ct_proto->Clear(); @@ -279,7 +278,7 @@ bool BasicPresolveAndGetFullyEncodedDomains( int new_linear1_size = 0; for (int i = 0; i < local_model.linear1_constraints.size(); ++i) { const int ct = local_model.linear1_constraints[i]; - const ConstraintProto& ct_proto = context->working_model->constraints(ct); + const ConstraintProto& ct_proto = context->Constraint(ct); if (ct_proto.constraint_case() != ConstraintProto::kLinear) continue; if (context->IsFixed(ct_proto.enforcement_literal(0))) { continue; @@ -297,10 +296,8 @@ bool BasicPresolveAndGetFullyEncodedDomains( for (const auto& [ref, ct] : ref_to_linear1) { auto it = ref_to_linear1.find(NegatedRef(ref)); if (it == ref_to_linear1.end()) continue; - const ConstraintProto& positive_ct = - context->working_model->constraints(ct); - const ConstraintProto& negative_ct = - context->working_model->constraints(it->second); + const ConstraintProto& positive_ct = context->Constraint(ct); + const ConstraintProto& negative_ct = context->Constraint(it->second); const Domain positive_domain = ReadDomainFromProto(positive_ct.linear()); const Domain negative_domain = ReadDomainFromProto(negative_ct.linear()); // b => x in D1 @@ -341,7 +338,7 @@ bool BasicPresolveAndGetFullyEncodedDomains( // // where D1, D2, ..., D_n are non overlapping. This works too for exactly_one. for (const int ct : local_model.constraints_linking_two_encoding_booleans) { - const ConstraintProto& ct_proto = context->working_model->constraints(ct); + const ConstraintProto& ct_proto = context->Constraint(ct); if (ct_proto.constraint_case() != ConstraintProto::kBoolOr && ct_proto.constraint_case() != ConstraintProto::kExactlyOne) { continue; @@ -361,8 +358,8 @@ bool BasicPresolveAndGetFullyEncodedDomains( encoding_detected = false; break; } - const Domain domain = ReadDomainFromProto( - context->working_model->constraints(it->second).linear()); + const Domain domain = + ReadDomainFromProto(context->Constraint(it->second).linear()); ref_and_domains.push_back({ref, domain}); if (!non_overlapping_domain.IntersectionWith(domain).IsEmpty()) { encoding_detected = false; @@ -395,8 +392,8 @@ bool BasicPresolveAndGetFullyEncodedDomains( "domain"); std::vector new_enforcement_literals(bool_or.literals().begin(), bool_or.literals().end()); - context->working_model->mutable_constraints(ct)->clear_bool_or(); - context->working_model->mutable_constraints(ct) + context->MutableConstraint(ct)->clear_bool_or(); + context->MutableConstraint(ct) ->mutable_exactly_one() ->mutable_literals() ->Add(new_enforcement_literals.begin(), @@ -428,7 +425,7 @@ bool BasicPresolveAndGetFullyEncodedDomains( // the linear1 holding the ~l0 encoding and the implicit encoding for // ~l0 coming from the bool_or are equivalent. ConstraintProto& negated_linear1_ct = - *context->working_model->mutable_constraints(it->second); + *context->MutableConstraint(it->second); const Domain negated_ct_domain = ReadDomainFromProto(negated_linear1_ct.linear()); const Domain potential_new_domain = @@ -455,7 +452,7 @@ bool DetectEncodedComplexDomain( PresolveContext* context, int ct_index, VariableEncodingLocalModel& local_model, absl::flat_hash_map* fully_encoded_domains, bool* changed) { - ConstraintProto* ct = context->working_model->mutable_constraints(ct_index); + ConstraintProto* ct = context->MutableConstraint(ct_index); *changed = false; if (context->ModelIsUnsat()) return false; @@ -718,23 +715,20 @@ bool DetectEncodedComplexDomain( } else if (domain_new_var_complement.IsEmpty()) { CHECK(context->SetLiteralToTrue(new_var)); } else { - local_model.linear1_constraints.push_back( - context->working_model->constraints_size()); - ConstraintProto* new_ct = context->working_model->add_constraints(); + local_model.linear1_constraints.push_back(context->NumConstraints()); + ConstraintProto* new_ct = context->AddConstraint(); new_ct->add_enforcement_literal(new_var); new_ct->mutable_linear()->add_vars(local_model.var); new_ct->mutable_linear()->add_coeffs(1); FillDomainInProto(domain_new_var, new_ct->mutable_linear()); - local_model.linear1_constraints.push_back( - context->working_model->constraints_size()); + local_model.linear1_constraints.push_back(context->NumConstraints()); local_model.bools_only_used_inside_the_local_model.insert( PositiveRef(new_var)); - new_ct = context->working_model->add_constraints(); + new_ct = context->AddConstraint(); new_ct->add_enforcement_literal(NegatedRef(new_var)); new_ct->mutable_linear()->add_vars(local_model.var); new_ct->mutable_linear()->add_coeffs(1); FillDomainInProto(domain_new_var_complement, new_ct->mutable_linear()); - context->UpdateNewConstraintsVariableUsage(); fully_encoded_domains->insert({new_var, domain_new_var}); fully_encoded_domains->insert( {NegatedRef(new_var), domain_new_var_complement}); @@ -779,7 +773,7 @@ bool DetectEncodedComplexDomain( int new_linear1_size = 0; for (int i = 0; i < local_model.linear1_constraints.size(); ++i) { const int ct = local_model.linear1_constraints[i]; - ConstraintProto* ct_proto = context->working_model->mutable_constraints(ct); + ConstraintProto* ct_proto = context->MutableConstraint(ct); if (bools_to_remove_set.contains( PositiveRef(ct_proto->enforcement_literal(0)))) { ct_proto->Clear(); @@ -832,7 +826,7 @@ bool MaybeTransferLinear1ToAnotherVariable( // In general constraint with more than two variable can't be removed. // Similarly for linear2 with non-fixed rhs as we would need to check the form // of all implied domain. - const auto& other_ct = context->working_model->constraints(other_c); + const auto& other_ct = context->Constraint(other_c); if (context->ConstraintToVars(other_c).size() != 2 || !other_ct.enforcement_literal().empty() || other_ct.constraint_case() == ConstraintProto::kLinear) { @@ -862,8 +856,7 @@ bool MaybeTransferLinear1ToAnotherVariable( Domain target_domain = implied.ContinuousMultiplicationBy(target.coeffs(0)) .AdditionWith(Domain(target.offset())); - target_domain = target_domain.IntersectionWith( - Domain(0, std::numeric_limits::max())); + target_domain = target_domain.IntersectionWith(Domain(0, kint64max)); // We have target = abs(expr). const Domain expr_domain = @@ -883,7 +876,7 @@ bool MaybeTransferLinear1ToAnotherVariable( // Applies transfer_f to all linear1. const Domain var_domain = context->DomainOf(var); for (const int c : to_rewrite) { - ConstraintProto* ct = context->working_model->mutable_constraints(c); + ConstraintProto* ct = context->MutableConstraint(c); if (ct->linear().vars(0) != var || ct->linear().coeffs(0) != 1) { // This shouldn't happen. LOG(INFO) << "Aborted in MaybeTransferLinear1ToAnotherVariable()"; @@ -913,7 +906,7 @@ bool MaybeTransferLinear1ToAnotherVariable( // Copy other_ct to the mapping model and delete var! context->NewMappingConstraint(other_ct, __FILE__, __LINE__); - context->working_model->mutable_constraints(other_c)->Clear(); + context->ClearConstraint(other_c); context->UpdateConstraintVariableUsage(other_c); context->MarkVariableAsRemoved(var); local_model.var = -1; diff --git a/ortools/sat/presolve_encoding_test.cc b/ortools/sat/presolve_encoding_test.cc index df16d0332a9..f718b54828a 100644 --- a/ortools/sat/presolve_encoding_test.cc +++ b/ortools/sat/presolve_encoding_test.cc @@ -52,7 +52,6 @@ TEST(CreateVariableEncodingLocalModelsTest, TrivialTest) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); const std::vector local_models = CreateVariableEncodingLocalModels(&context); ASSERT_EQ(local_models.size(), 1); @@ -88,7 +87,6 @@ TEST(CreateVariableEncodingLocalModelsTest, BasicTest) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); const std::vector local_models = CreateVariableEncodingLocalModels(&context); ASSERT_EQ(local_models.size(), 1); @@ -149,7 +147,6 @@ TEST(CreateVariableEncodingLocalModelsTest, OneBooleanUsedElsewhere) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); const std::vector local_models = CreateVariableEncodingLocalModels(&context); ASSERT_EQ(local_models.size(), 1); @@ -228,7 +225,6 @@ TEST(CreateVariableEncodingLocalModelsTest, TwoVars) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); std::vector local_models = CreateVariableEncodingLocalModels(&context); ASSERT_EQ(local_models.size(), 2); @@ -313,7 +309,6 @@ TEST(BasicPresolveAndGetFullyEncodedDomainsTest, EncodingWithBoolOr) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); std::vector local_models = CreateVariableEncodingLocalModels(&context); @@ -330,7 +325,7 @@ TEST(BasicPresolveAndGetFullyEncodedDomainsTest, EncodingWithBoolOr) { ConstraintProto expected_exactly_one = ParseTestProto(R"pb( exactly_one { literals: [ 0, 1, 2 ] } )pb"); - EXPECT_THAT(context.working_model->constraints(), + EXPECT_THAT(context.WorkingModel().constraints(), testing::Contains(testing::EqualsProto(expected_exactly_one))); } @@ -400,7 +395,6 @@ TEST(DetectAllEncodedComplexDomainTest, BasicTest) { PresolveContext context(&model, &model_proto, &mapping_model); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); std::vector local_models = CreateVariableEncodingLocalModels(&context); ASSERT_TRUE(DetectAllEncodedComplexDomain(&context, local_models[0])); @@ -454,7 +448,7 @@ TEST(DetectAllEncodedComplexDomainTest, BasicTest) { coeffs: [ 2 ] } )pb"); - EXPECT_THAT(context.working_model, testing::EqualsProto(expected_model)); + EXPECT_THAT(context.WorkingModel(), testing::EqualsProto(expected_model)); } } // namespace diff --git a/ortools/sat/presolve_util.cc b/ortools/sat/presolve_util.cc index 8650ac468b4..0336b54e11f 100644 --- a/ortools/sat/presolve_util.cc +++ b/ortools/sat/presolve_util.cc @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -29,6 +28,7 @@ #include "absl/random/distributions.h" #include "absl/types/span.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/util.h" @@ -718,8 +718,8 @@ bool FindSingleLinearDifference(const LinearConstraintProto& lin1, int j = 0; while (i < size || j < size) { // Note that we can't have both undefined or the loop would have exited. - const int v1 = i < size ? lin1.vars(i) : std::numeric_limits::max(); - const int v2 = j < size ? lin2.vars(j) : std::numeric_limits::max(); + const int v1 = i < size ? lin1.vars(i) : kint32max; + const int v2 = j < size ? lin2.vars(j) : kint32max; // Same term, continue. if (v1 == v2 && lin1.coeffs(i) == lin2.coeffs(j)) { diff --git a/ortools/sat/primary_variables.cc b/ortools/sat/primary_variables.cc index f074bb6c165..1071ebedb07 100644 --- a/ortools/sat/primary_variables.cc +++ b/ortools/sat/primary_variables.cc @@ -26,6 +26,7 @@ #include "absl/container/btree_set.h" #include "absl/log/check.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/bitset.h" @@ -588,7 +589,7 @@ bool ComputeAllVariablesFromPrimaryVariables( (*solution)[var] = sum_of_free_variables / coeff_of_var; } break; case ConstraintProto::kLinMax: { - int64_t max = std::numeric_limits::min(); + int64_t max = kint64min; for (const auto& expr : ct.lin_max().exprs()) { int64_t expr_value = expr.offset(); for (int j = 0; j < expr.vars_size(); ++j) { diff --git a/ortools/sat/python/BUILD.bazel b/ortools/sat/python/BUILD.bazel index c288caea9e2..a203b7bab84 100644 --- a/ortools/sat/python/BUILD.bazel +++ b/ortools/sat/python/BUILD.bazel @@ -14,7 +14,7 @@ # Description: python wrapping of the C++ code at ../ load("@pip_deps//:requirements.bzl", "requirement") -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") load("@rules_cc//cc:cc_binary.bzl", "cc_binary") load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_python//python:py_library.bzl", "py_library") @@ -26,11 +26,60 @@ cc_library( hdrs = ["linear_expr_doc.h"], ) +pybind_library( + name = "pybind_constraint", + srcs = ["pybind_constraint.cc"], + hdrs = ["pybind_constraint.h"], + linkstatic = 1, + deps = [ + ":linear_expr", + "//ortools/base:types", + "//ortools/port:proto_utils", + "//ortools/sat:cp_model_cc_proto", + "//ortools/sat:cp_model_utils", + "//ortools/util:saturated_arithmetic", + "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/container:flat_hash_map", + "@abseil-cpp//absl/log:check", + "@abseil-cpp//absl/strings", + ], +) + +pybind_library( + name = "pybind_linearexpr", + srcs = ["pybind_linearexpr.cc"], + hdrs = ["pybind_linearexpr.h"], + linkstatic = 1, + deps = [ + ":linear_expr", + ":linear_expr_doc", + "//ortools/base:types", + "//ortools/util:sorted_interval_list", + "@abseil-cpp//absl/base:dynamic_annotations", + "@abseil-cpp//absl/functional:any_invocable", + "@abseil-cpp//absl/strings", + ], +) + +pybind_library( + name = "pybind_solver", + srcs = ["pybind_solver.cc"], + hdrs = ["pybind_solver.h"], + linkstatic = 1, + deps = [ + ":linear_expr", + "//ortools/sat:cp_model_utils", + "//ortools/sat:swig_helper", + "@abseil-cpp//absl/strings", + ], +) + cc_library( name = "linear_expr", srcs = ["linear_expr.cc"], hdrs = ["linear_expr.h"], deps = [ + "//ortools/base:types", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", "//ortools/util:fp_roundtrip_conv", @@ -80,6 +129,9 @@ pybind_extension( ":linear_expr", ":linear_expr_doc", ":proto_builder_pybind11", + ":pybind_constraint", + ":pybind_linearexpr", + ":pybind_solver", "//ortools/port:proto_utils", "//ortools/sat:cp_model_cc_proto", "//ortools/sat:cp_model_utils", diff --git a/ortools/sat/python/CMakeLists.txt b/ortools/sat/python/CMakeLists.txt index 68a3f48c8cb..091206e5e9e 100644 --- a/ortools/sat/python/CMakeLists.txt +++ b/ortools/sat/python/CMakeLists.txt @@ -46,7 +46,47 @@ add_custom_command( COMMENT "Generate C++ proto_builder_pybind11.h" VERBATIM) +# Linear Expression doc +add_library(linear_expr_doc INTERFACE) +target_include_directories(linear_expr_doc INTERFACE ${CMAKE_CURRENT_SOURCE_DIR}) + +add_library(pybind_constraint OBJECT pybind_constraint.h pybind_constraint.cc) +set_target_properties(pybind_constraint PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(pybind_constraint PUBLIC + $ + $) +target_link_libraries(pybind_constraint PRIVATE + ${PROJECT_NAMESPACE}::ortools + absl::strings + pybind11::module) + +add_library(pybind11_linearexpr OBJECT pybind_linearexpr.h pybind_linearexpr.cc) +set_target_properties(pybind11_linearexpr PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(pybind11_linearexpr PUBLIC + $ + $) +target_link_libraries(pybind11_linearexpr PRIVATE + ${PROJECT_NAMESPACE}::ortools + linear_expr_doc + absl::strings + pybind11::module) + +add_library(pybind11_solver OBJECT pybind_solver.h pybind_solver.cc) +set_target_properties(pybind11_solver PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(pybind11_solver PUBLIC + $ + $) +target_link_libraries(pybind11_solver PRIVATE + ${PROJECT_NAMESPACE}::ortools + absl::strings + pybind11::module) + pybind11_add_module(cp_model_helper_pybind11 MODULE cp_model_helper.cc proto_builder_pybind11.h) +target_sources(cp_model_helper_pybind11 PRIVATE + $ + $ + $ +) set_target_properties(cp_model_helper_pybind11 PROPERTIES LIBRARY_OUTPUT_NAME "cp_model_helper") diff --git a/ortools/sat/python/cp_model_helper.cc b/ortools/sat/python/cp_model_helper.cc index 3efefa386ca..e7e716e9719 100644 --- a/ortools/sat/python/cp_model_helper.cc +++ b/ortools/sat/python/cp_model_helper.cc @@ -42,2293 +42,23 @@ #undef ValgrindSlowdown #include "ortools/port/proto_utils.h" // IWYU: keep #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_utils.h" -#include "ortools/sat/python/linear_expr.h" -#include "ortools/sat/python/linear_expr_doc.h" +#include "ortools/sat/python/pybind_constraint.h" +#include "ortools/sat/python/pybind_linearexpr.h" +#include "ortools/sat/python/pybind_solver.h" #include "ortools/sat/sat_parameters.pb.h" // IWYU: keep -#include "ortools/sat/swig_helper.h" -#include "ortools/util/saturated_arithmetic.h" -#include "ortools/util/sorted_interval_list.h" -#include "pybind11/attr.h" -#include "pybind11/cast.h" -#include "pybind11/functional.h" -#include "pybind11/gil.h" -#include "pybind11/numpy.h" #include "pybind11/pybind11.h" -#include "pybind11/pytypes.h" #include "pybind11/stl.h" namespace py = pybind11; namespace operations_research::sat::python { -void ThrowError(PyObject* py_exception, const std::string& message) { - PyErr_SetString(py_exception, message.c_str()); - throw py::error_already_set(); -} - -// We extend the SolverWrapper class to keep track of the local error already -// set. -class ExtSolveWrapper : public SolveWrapper { - public: - mutable std::optional local_error_already_set_; -}; - -// A trampoline class to override the OnSolutionCallback method to acquire the -// GIL. -class PySolutionCallback : public SolutionCallback { - public: - using SolutionCallback::SolutionCallback; /* Inherit constructors */ - void OnSolutionCallback() const override { - ::py::gil_scoped_acquire acquire; - try { - PYBIND11_OVERRIDE_PURE( - void, /* Return type */ - SolutionCallback, /* Parent class */ - OnSolutionCallback, /* Name of function */ - /* This function has no arguments. The trailing comma - in the previous line is needed for some compilers */ - ); - } catch (py::error_already_set& e) { - // We assume this code is serialized as the gil is held. - ExtSolveWrapper* solve_wrapper = static_cast(wrapper()); - if (!solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_ = e; - } - StopSearch(); - } - } -}; - -class ResponseHelper { - public: - static bool BooleanValue(std::shared_ptr response, - std::shared_ptr lit) { - const int index = lit->index(); - if (index >= 0) { - return response->solution(index) != 0; - } else { - return response->solution(NegatedRef(index)) == 0; - } - } - - static bool FixedBooleanValue(std::shared_ptr response, - bool lit) { - return lit; - } - - static std::vector SufficientAssumptionsForInfeasibility( - std::shared_ptr response) { - return std::vector( - response->sufficient_assumptions_for_infeasibility().begin(), - response->sufficient_assumptions_for_infeasibility().end()); - } - - static double FloatValue(std::shared_ptr response, - std::shared_ptr expr) { - FloatExprVisitor visitor; - visitor.AddToProcess(expr, 1); - return visitor.Evaluate(*response); - } - - static double FixedFloatValue(std::shared_ptr response, - double value) { - return value; - } - - static int64_t Value(std::shared_ptr response, - std::shared_ptr expr) { - int64_t value; - IntExprVisitor visitor; - visitor.AddToProcess(expr, 1); - if (!visitor.Evaluate(*response, &value)) { - ThrowError(PyExc_ValueError, - absl::StrCat("Failed to evaluate linear expression: ", - expr->DebugString())); - } - return value; - } - - static int64_t FixedValue(std::shared_ptr response, - int64_t value) { - return value; - } -}; - -// Checks that the result is not null and throws an error if it is. -std::shared_ptr CheckBoundedLinearExpression( - std::shared_ptr result, - std::shared_ptr lhs, - std::shared_ptr rhs = nullptr) { - if (!result->ok()) { - if (rhs == nullptr) { - ThrowError(PyExc_TypeError, - absl::StrCat("Linear constraints only accept integer values " - "and coefficients: ", - lhs->DebugString())); - } else { - ThrowError(PyExc_TypeError, - absl::StrCat("Linear constraints only accept integer values " - "and coefficients: ", - lhs->DebugString(), " and ", rhs->DebugString())); - } - } - return result; -} - -void RaiseIfNone(std::shared_ptr expr) { - if (expr == nullptr) { - ThrowError(PyExc_TypeError, - "Linear constraints do not accept None as argument."); - } -} - -void ProcessExprArg( - const py::handle& arg, - absl::AnyInvocable)> on_linear_expr, - absl::AnyInvocable on_int_constant, - absl::AnyInvocable on_float_constant) { - if (py::isinstance(arg)) { - on_linear_expr(arg.cast>()); - } else if (py::isinstance(arg)) { - on_int_constant(arg.cast()); - } else if (py::isinstance(arg)) { - on_float_constant(arg.cast()); - } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer")) { - if (getattr(arg, "is_integer")().cast()) { - on_int_constant(arg.cast()); - } else { - on_float_constant(arg.cast()); - } - } else { - py::type objtype = py::type::of(arg); - const std::string type_name = objtype.attr("__name__").cast(); - ThrowError(PyExc_TypeError, - absl::StrCat("LinearExpr::sum() only accept linear " - "expressions and constants as argument: '", - absl::CEscape(type_name), "'")); - } -} - -void ProcessConstantArg(const py::handle& arg, - absl::AnyInvocable on_int_constant, - absl::AnyInvocable on_float_constant) { - if (py::isinstance(arg)) { - on_int_constant(arg.cast()); - } else if (py::isinstance(arg)) { - on_float_constant(arg.cast()); - } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer")) { - if (getattr(arg, "is_integer")().cast()) { - on_int_constant(arg.cast()); - } else { - on_float_constant(arg.cast()); - } - } else { - py::type objtype = py::type::of(arg); - const std::string type_name = objtype.attr("__name__").cast(); - ThrowError(PyExc_TypeError, - absl::StrCat("LinearExpr::weighted_sum() only accept constants " - "as coefficients: '", - absl::CEscape(type_name), "'")); - } -} - -std::shared_ptr SumArguments(py::args expressions) { - std::vector> linear_exprs; - int64_t int_offset = 0; - double float_offset = 0.0; - bool has_floats = false; - - const auto process_arg = [&](const py::handle& arg) -> void { - ProcessExprArg( - arg, - [&](std::shared_ptr expr) { linear_exprs.push_back(expr); }, - [&](int64_t value) { int_offset += value; }, - [&](double value) { - if (value != 0.0) { - float_offset += value; - has_floats = true; - } - }); - }; - - if (expressions.size() == 1 && py::isinstance(expressions[0])) { - // Normal list or tuple argument. - py::sequence elements = expressions[0].cast(); - linear_exprs.reserve(elements.size()); - for (const py::handle& expr : elements) { - process_arg(expr); - } - } else { // Direct sum(x, y, 3, ..) without []. - linear_exprs.reserve(expressions.size()); - for (const py::handle expr : expressions) { - process_arg(expr); - } - } - - // If there are floats, we add the int offset to the float offset. - if (has_floats) { - float_offset += static_cast(int_offset); - int_offset = 0; - } - - if (linear_exprs.empty()) { - if (has_floats) { - return std::make_shared(float_offset); - } else { - return std::make_shared(int_offset); - } - } else if (linear_exprs.size() == 1) { - if (has_floats) { - if (float_offset == 0.0) { - return linear_exprs[0]; - } else { - return std::make_shared(linear_exprs[0], 1.0, - float_offset); - } - } else if (int_offset != 0) { - return std::make_shared(linear_exprs[0], 1, int_offset); - } else { - return linear_exprs[0]; - } - } else { - if (has_floats) { - return std::make_shared(linear_exprs, 0, float_offset); - } else { - return std::make_shared(linear_exprs, int_offset, 0.0); - } - } -} - -std::shared_ptr WeightedSumArguments(py::sequence expressions, - py::sequence coefficients) { - const int64_t size = expressions.size(); - if (size != coefficients.size()) { - ThrowError(PyExc_ValueError, - absl::StrCat("LinearExpr::weighted_sum() requires the same " - "number of arguments and coefficients: ", - size, " != ", coefficients.size())); - } - - std::vector> linear_exprs; - std::vector int_coeffs; - std::vector float_coeffs; - linear_exprs.reserve(size); - int_coeffs.reserve(size); - float_coeffs.reserve(size); - int64_t int_offset = 0; - double float_offset = 0.0; - bool has_floats = false; - bool fast_coeffs = false; - const void* raw_coeffs = nullptr; - Py_ssize_t coeff_stride = 0; - - enum { kInt64, kInt32, kDouble } coeff_type = kInt64; - - if (py::isinstance(coefficients)) { - py::array arr = coefficients.cast(); - - if (arr.ndim() == 1 && arr.size() == size) { - if (py::isinstance>(arr)) { - fast_coeffs = true; - raw_coeffs = arr.data(); - coeff_stride = arr.strides(0); - coeff_type = kInt64; - } else if (py::isinstance>(arr)) { - fast_coeffs = true; - raw_coeffs = arr.data(); - coeff_stride = arr.strides(0); - coeff_type = kInt32; - } else if (py::isinstance>(arr)) { - fast_coeffs = true; - raw_coeffs = arr.data(); - coeff_stride = arr.strides(0); - coeff_type = kDouble; - } - } - } - - for (int64_t i = 0; i < size; ++i) { - // --- Parse Coefficient --- - int64_t c_int = 0; - double c_float = 0.0; - bool c_is_float = false; - bool c_is_zero = false; - if (fast_coeffs) { - const char* ptr = static_cast(raw_coeffs) + i * coeff_stride; - if (coeff_type == kInt64) { - c_int = *reinterpret_cast(ptr); - if (c_int == 0) { - c_is_zero = true; - } else { - c_float = static_cast(c_int); - } - } else if (coeff_type == kInt32) { - c_int = *reinterpret_cast(ptr); - if (c_int == 0) { - c_is_zero = true; - } else { - c_float = static_cast(c_int); - } - } else { // kDouble - c_float = *reinterpret_cast(ptr); - if (c_float == 0.0) { - c_is_zero = true; - } else { - c_is_float = true; - has_floats = true; - } - } - } else { - const py::handle coeff_obj = coefficients[i]; - if (py::isinstance(coeff_obj)) { - c_int = coeff_obj.cast(); - if (c_int == 0) { - c_is_zero = true; - } else { - c_float = static_cast(c_int); - } - } else if (py::isinstance(coeff_obj)) { - c_float = coeff_obj.cast(); - if (c_float == 0.0) { - c_is_zero = true; - } else { - c_is_float = true; - has_floats = true; - } - } else if (hasattr(coeff_obj, "dtype") && - hasattr(coeff_obj, "is_integer")) { - if (getattr(coeff_obj, "is_integer")().cast()) { - c_int = coeff_obj.cast(); - if (c_int == 0) { - c_is_zero = true; - } else { - c_float = static_cast(c_int); - } - } else { - c_float = coeff_obj.cast(); - if (c_float == 0.0) { - c_is_zero = true; - } else { - c_is_float = true; - has_floats = true; - } - } - } else { - py::type objtype = py::type::of(coeff_obj); - const std::string type_name = - objtype.attr("__name__").cast(); - ThrowError( - PyExc_TypeError, - absl::StrCat("LinearExpr::weighted_sum() only accept constants " - "as coefficients: '", - absl::CEscape(type_name), "'")); - } - } - - if (c_is_zero) continue; - - // --- Parse Expression --- - const py::handle expr_obj = expressions[i]; - if (py::isinstance(expr_obj)) { - linear_exprs.push_back(expr_obj.cast>()); - if (c_is_float) { - float_coeffs.push_back(c_float); - } else { - int_coeffs.push_back(c_int); - float_coeffs.push_back(c_float); - } - } else if (py::isinstance(expr_obj)) { - int64_t val = expr_obj.cast(); - if (val == 0) continue; - if (c_is_float) { - float_offset += c_float * static_cast(val); - } else { - int_offset += c_int * val; - } - } else if (py::isinstance(expr_obj)) { - double val = expr_obj.cast(); - if (val == 0.0) continue; - has_floats = true; - float_offset += c_float * val; - } else if (hasattr(expr_obj, "dtype") && hasattr(expr_obj, "is_integer")) { - if (getattr(expr_obj, "is_integer")().cast()) { - int64_t val = expr_obj.cast(); - if (val == 0) continue; - if (c_is_float) { - float_offset += c_float * static_cast(val); - } else { - int_offset += c_int * val; - } - } else { - double val = expr_obj.cast(); - if (val == 0.0) continue; - has_floats = true; - float_offset += c_float * val; - } - } else { - py::type objtype = py::type::of(expr_obj); - const std::string type_name = - objtype.attr("__name__").cast(); - ThrowError(PyExc_TypeError, - absl::StrCat("LinearExpr::weighted_sum() only accept linear " - "expressions and constants as argument: '", - absl::CEscape(type_name), "'")); - } - } - - // Correct the float offset if there are int offsets. - if (has_floats) { - float_offset += static_cast(int_offset); - int_offset = 0; - } - - if (linear_exprs.empty()) { - if (has_floats) { - return std::make_shared(float_offset); - } else { - return std::make_shared(int_offset); - } - } else if (linear_exprs.size() == 1) { - if (has_floats) { - return std::make_shared(linear_exprs[0], float_coeffs[0], - float_offset); - } else if (int_offset != 0 || int_coeffs[0] != 1) { - return std::make_shared(linear_exprs[0], int_coeffs[0], - int_offset); - } else { - return linear_exprs[0]; - } - } else { - if (has_floats) { - return std::make_shared(linear_exprs, float_coeffs, - float_offset); - } else { - return std::make_shared(linear_exprs, int_coeffs, - int_offset); - } - } -} - -void LinearExprToProto(const py::handle& arg, int64_t multiplier, - LinearExpressionProto* proto) { - proto->Clear(); - if (py::isinstance(arg)) { - std::shared_ptr expr = arg.cast>(); - IntExprVisitor visitor; - visitor.AddToProcess(expr, multiplier); - std::vector> vars; - std::vector coeffs; - int64_t offset = 0; - if (!visitor.Process(&vars, &coeffs, &offset)) { - ThrowError(PyExc_ValueError, - absl::StrCat("Failed to convert integer linear expression: ", - expr->DebugString())); - } - for (const auto& var : vars) { - proto->add_vars(var->index()); - } - for (const int64_t coeff : coeffs) { - proto->add_coeffs(coeff); - } - proto->set_offset(offset); - } else if (py::isinstance(arg)) { - int64_t value = arg.cast(); - proto->set_offset(value * multiplier); - } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && - getattr(arg, "is_integer")().cast()) { - int64_t value = arg.cast(); - proto->set_offset(value * multiplier); - } else { - py::type objtype = py::type::of(arg); - const std::string type_name = objtype.attr("__name__").cast(); - py::print(arg); - ThrowError(PyExc_TypeError, - absl::StrCat("Cannot convert '", absl::CEscape(type_name), - "' to a linear expression.")); - } -} - -class Constraint; -class IntervalVar; - -enum class BoolArgumentConstraint { - kAtMostOne, - kBoolAnd, - kBoolOr, - kBoolXor, - kExactlyOne, -}; - -enum class LinearArgumentConstraint { - kDiv, - kMax, - kMin, - kMod, - kProd, -}; - -class CpBaseModel : public std::enable_shared_from_this { - public: - explicit CpBaseModel(std::shared_ptr model_proto) - : model_proto_(model_proto == nullptr ? std::make_shared() - : model_proto), - numpy_bool_type_(py::dtype::of().attr("type").cast()) { - if (model_proto != nullptr) RebuildConstantMap(); - } - - std::shared_ptr model_proto() const { return model_proto_; } - - int GetOrMakeIndexFromConstant(int64_t value) { - auto it = cache_.find(value); - if (it != cache_.end()) return it->second; - const int index = model_proto_->variables_size(); - IntegerVariableProto* const_var = model_proto_->add_variables(); - const_var->add_domain(value); - const_var->add_domain(value); - cache_[value] = index; - return index; - } - - void RebuildConstantMap() { - cache_.clear(); - for (int i = 0; i < model_proto_->variables_size(); ++i) { - const IntegerVariableProto& var = model_proto_->variables(i); - if (var.domain_size() == 2 && var.domain(0) == var.domain(1) && - var.name().empty()) { // Constants do not have names. - cache_[var.domain(0)] = i; - } - } - } - - int GetOrMakeBooleanIndex(py::handle literal) { - if (py::isinstance(literal)) { - std::shared_ptr var = literal.cast>(); - AssertVariableIsBoolean(var); - return var->index(); - } else if (py::isinstance(literal)) { - std::shared_ptr not_var = - literal.cast>(); - AssertVariableIsBoolean(not_var); - return not_var->index(); - } else if (IsBooleanValue(literal)) { - const bool value = literal.cast(); - if (value) { - return GetOrMakeIndexFromConstant(1); - } else { - return GetOrMakeIndexFromConstant(0); - } - } else if (py::isinstance(literal)) { - const int64_t value = literal.cast(); - if (value == 1 || value == -1) { // -1 = ~False. - return GetOrMakeIndexFromConstant(1); - } - if (value == 0 || value == -2) { // -2 = ~True. - return GetOrMakeIndexFromConstant(0); - } - ThrowError(PyExc_TypeError, absl::StrCat("Invalid literal: ", value)); - } else { - py::type objtype = py::type::of(literal); - const std::string type_name = - objtype.attr("__name__").cast(); - ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: '", - absl::CEscape(type_name), "'")); - } - return 0; // Unreachable. - } - - int GetOrMakeVariableIndex(py::handle arg) { - if (py::isinstance(arg)) { - std::shared_ptr var = arg.cast>(); - return var->index(); - } else if (py::isinstance(arg)) { - return GetOrMakeIndexFromConstant(arg.cast()); - } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && - getattr(arg, "is_integer")().cast()) { - return GetOrMakeIndexFromConstant(arg.cast()); - } else { - py::type objtype = py::type::of(arg); - const std::string type_name = - objtype.attr("__name__").cast(); - ThrowError(PyExc_TypeError, - absl::StrCat("GetOrMakeVariableIndex() only accept integer " - "variables or constants as argument: '", - absl::CEscape(type_name), "'")); - } - return 0; // Unreachable. - } - - void AssertVariableIsBoolean(std::shared_ptr literal) { - if (PositiveRef(literal->index()) >= model_proto_->variables_size()) { - ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: ", - literal->ToString())); - } - IntegerVariableProto* var = - model_proto_->mutable_variables(PositiveRef(literal->index())); - if (var->domain_size() != 2 || var->domain(0) < 0 || var->domain(1) > 1) { - ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: ", - literal->ToString())); - } - } - - bool IsBooleanValue(py::handle value) { - return py::isinstance(value) || - py::isinstance(value, numpy_bool_type_); - } - - std::shared_ptr AddAllDifferentInternal(py::args exprs); - - std::shared_ptr AddAutomatonInternal( - py::sequence transition_expressions, int64_t starting_state, - const std::vector& final_states, - const std::vector>& transition_triples); - - std::shared_ptr AddBoolArgumentConstraintInternal( - BoolArgumentConstraint type, py::args literals); - - std::shared_ptr AddBoundedLinearExpressionInternal( - BoundedLinearExpression* ble); - - std::shared_ptr AddElementInternal(const py::handle& index, - py::sequence exprs, - const py::handle& target); - - std::shared_ptr AddInverseInternal(py::sequence direct, - py::sequence inverse); - - std::shared_ptr AddLinearArgumentConstraintInternal( - LinearArgumentConstraint type, const py::handle& target, py::args exprs); - - std::shared_ptr AddReservoirInternal(py::sequence times, - py::sequence level_changes, - py::sequence actives, - int64_t min_level, - int64_t max_level); - - std::shared_ptr AddTableInternal( - py::sequence exprs, const std::vector>& tuples, - bool negated); - - std::shared_ptr NewIntervalVarInternal(const std::string& name, - const py::handle& start, - const py::handle& size, - const py::handle& end, - py::sequence literals); - - std::shared_ptr AddNoOverlapInternal( - const std::vector>& intervals); - - std::shared_ptr AddNoOverlap2DInternal( - const std::vector>& x_intervals, - const std::vector>& y_intervals); - - std::shared_ptr AddCumulativeInternal( - const std::vector>& intervals, - py::sequence demands, const py::handle& capacity); - - std::shared_ptr AddCircuitInternal( - const std::vector>& arcs); - - std::shared_ptr AddRoutesInternal( - const std::vector>& arcs); - - private: - std::shared_ptr model_proto_; - absl::flat_hash_map cache_; - py::type numpy_bool_type_; -}; - -class Constraint { - public: - // We need to store the CpBaseModel to convert enforcement literals to - // indices. - Constraint(std::shared_ptr model, int index) - : model_(model), index_(index) {} - - int index() const { return index_; } - - std::shared_ptr model_proto() const { - return model_->model_proto(); - } - - ConstraintProto* proto() const { - return model_->model_proto()->mutable_constraints(index_); - } - - std::shared_ptr model() const { return model_; } - - std::string name() const { return proto()->name(); } - void SetName(const std::string& name) { proto()->set_name(name); } - void ClearName() { proto()->clear_name(); } - - std::string ToString() const { - return absl::StrCat("Constraint(index=", index_, ", ", - ProtobufDebugString(*proto()), ")"); - } - - private: - std::shared_ptr model_; - int index_; -}; - -std::shared_ptr CpBaseModel::AddAllDifferentInternal( - py::args exprs) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - if (exprs.size() == 1 && py::isinstance(exprs[0])) { - for (const auto& expr : exprs[0]) { - LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); - } - } else { - for (const auto& expr : exprs) { - LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); - } - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddAutomatonInternal( - py::sequence transition_expressions, int64_t starting_state, - const std::vector& final_states, - const std::vector>& transition_triples) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - for (const auto& expr : transition_expressions) { - LinearExprToProto(expr, 1, ct->mutable_automaton()->add_exprs()); - } - ct->mutable_automaton()->set_starting_state(starting_state); - ct->mutable_automaton()->mutable_final_states()->Add(final_states.begin(), - final_states.end()); - for (const auto& tuple : transition_triples) { - if (tuple.size() != 3) { - ThrowError(PyExc_ValueError, - absl::StrCat("transition (", absl::StrJoin(tuple, ","), - ") has the wrong arity != 3")); - } - ct->mutable_automaton()->add_transition_tail(tuple[0]); - ct->mutable_automaton()->add_transition_label(tuple[1]); - ct->mutable_automaton()->add_transition_head(tuple[2]); - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddBoolArgumentConstraintInternal( - BoolArgumentConstraint type, py::args literals) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - BoolArgumentProto* proto = nullptr; - switch (type) { - case BoolArgumentConstraint::kAtMostOne: - proto = ct->mutable_at_most_one(); - break; - case BoolArgumentConstraint::kBoolAnd: - proto = ct->mutable_bool_and(); - break; - case BoolArgumentConstraint::kBoolOr: - proto = ct->mutable_bool_or(); - break; - case BoolArgumentConstraint::kBoolXor: - proto = ct->mutable_bool_xor(); - break; - case BoolArgumentConstraint::kExactlyOne: - proto = ct->mutable_exactly_one(); - break; - default: - ThrowError(PyExc_ValueError, - absl::StrCat("Unknown boolean argument constraint: ", type)); - } - if (literals.size() == 1 && py::isinstance(literals[0])) { - for (const auto& literal : literals[0]) { - proto->add_literals(GetOrMakeBooleanIndex(literal)); - } - } else { - for (const auto& literal : literals) { - proto->add_literals(GetOrMakeBooleanIndex(literal)); - } - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddBoundedLinearExpressionInternal( - BoundedLinearExpression* ble) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - for (const auto& var : ble->vars()) { - ct->mutable_linear()->add_vars(var->index()); - } - for (const int64_t coeff : ble->coeffs()) { - ct->mutable_linear()->add_coeffs(coeff); - } - const int64_t offset = ble->offset(); - const Domain& bounds = ble->bounds(); - for (const int64_t bound : bounds.FlattenedIntervals()) { - if (bound == std::numeric_limits::min() || - bound == std::numeric_limits::max()) { - ct->mutable_linear()->add_domain(bound); - } else { - ct->mutable_linear()->add_domain(CapSub(bound, offset)); - } - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddElementInternal( - const py::handle& index, py::sequence exprs, const py::handle& target) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - LinearExprToProto(index, 1, ct->mutable_element()->mutable_linear_index()); - for (const auto& expr : exprs) { - LinearExprToProto(expr, 1, ct->mutable_element()->add_exprs()); - } - LinearExprToProto(target, 1, ct->mutable_element()->mutable_linear_target()); - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddInverseInternal( - py::sequence direct, py::sequence inverse) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - for (const auto& var : direct) { - ct->mutable_inverse()->add_f_direct(GetOrMakeVariableIndex(var)); - } - for (const auto& var : inverse) { - ct->mutable_inverse()->add_f_inverse(GetOrMakeVariableIndex(var)); - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddLinearArgumentConstraintInternal( - LinearArgumentConstraint type, const py::handle& target, py::args exprs) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - LinearArgumentProto* proto; - int64_t multiplier = 1; - switch (type) { - case LinearArgumentConstraint::kDiv: - proto = ct->mutable_int_div(); - break; - case LinearArgumentConstraint::kMax: - proto = ct->mutable_lin_max(); - break; - case LinearArgumentConstraint::kMin: - proto = ct->mutable_lin_max(); - multiplier = -1; - break; - case LinearArgumentConstraint::kMod: - proto = ct->mutable_int_mod(); - break; - case LinearArgumentConstraint::kProd: - proto = ct->mutable_int_prod(); - break; - default: - ThrowError(PyExc_ValueError, - absl::StrCat("Unknown integer argument constraint: ", type)); - } - - LinearExprToProto(target, multiplier, proto->mutable_target()); - - if (exprs.size() == 1 && py::isinstance(exprs[0])) { - for (const auto& expr : exprs[0]) { - LinearExprToProto(expr, multiplier, proto->add_exprs()); - } - } else { - for (const auto& expr : exprs) { - LinearExprToProto(expr, multiplier, proto->add_exprs()); - } - } - - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddReservoirInternal( - py::sequence times, py::sequence level_changes, py::sequence actives, - int64_t min_level, int64_t max_level) { - const int ct_index = model_proto_->constraints_size(); - ReservoirConstraintProto* proto = - model_proto_->add_constraints()->mutable_reservoir(); - for (const auto& time : times) { - LinearExprToProto(time, 1, proto->add_time_exprs()); - } - for (const auto& change : level_changes) { - LinearExprToProto(change, 1, proto->add_level_changes()); - } - for (const auto& active : actives) { - proto->add_active_literals(GetOrMakeBooleanIndex(active)); - } - proto->set_min_level(min_level); - proto->set_max_level(max_level); - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddTableInternal( - py::sequence exprs, const std::vector>& tuples, - bool negated) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - const int num_exprs = exprs.size(); - for (const auto& expr : exprs) { - LinearExprToProto(expr, 1, ct->mutable_table()->add_exprs()); - } - for (const auto& tuple : tuples) { - if (tuple.size() != num_exprs) { - ThrowError(PyExc_ValueError, - absl::StrCat("Tuple (", absl::StrJoin(tuple, ","), - ") has the wrong arity != ", num_exprs)); - } - ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end()); - } - ct->mutable_table()->set_negated(negated); - return std::make_shared(shared_from_this(), ct_index); -} - -std::string ShortName(int literal, std::shared_ptr model_proto) { - const int var = PositiveRef(literal); - const IntegerVariableProto& var_proto = model_proto->variables(var); - const std::string& var_name = - var_proto.name().empty() ? absl::StrCat("i", var) : var_proto.name(); - if (literal < 0) { - return absl::StrCat("not(", var_name, ")"); - } else { - return var_name; - } -} - -std::string ShortExprName(const LinearExpressionProto& expr, - std::shared_ptr model_proto) { - if (expr.vars().empty()) { - return absl::StrCat(expr.offset()); - } else { - const IntegerVariableProto& var_proto = - model_proto->variables(expr.vars(0)); - const std::string& var_name = var_proto.name().empty() - ? absl::StrCat("i", expr.vars(0)) - : var_proto.name(); - const int64_t coeff = expr.coeffs(0); - std::string result; - if (coeff == 1) { - result = var_name; - } else if (coeff == -1) { - result = absl::StrCat("-", var_name); - } else if (coeff != 0) { - result = absl::StrCat(coeff, " * ", var_name); - } - if (expr.offset() > 0) { - absl::StrAppend(&result, " + ", expr.offset()); - } else if (expr.offset() < 0) { - absl::StrAppend(&result, " - ", -expr.offset()); - } - return result; - } -} - -std::shared_ptr RebuildFromLinearExpressionProto( - const LinearExpressionProto& proto, - std::shared_ptr model_proto) { - if (proto.vars().empty()) { - return LinearExpr::ConstantInt(proto.offset()); - } else if (proto.vars_size() == 1) { - return LinearExpr::AffineInt( - std::make_shared(model_proto, proto.vars(0)), proto.coeffs(0), - proto.offset()); - } else { - std::vector> vars; - vars.reserve(proto.vars_size()); - for (const int var : proto.vars()) { - vars.push_back(std::make_shared(model_proto, var)); - } - return std::make_shared(vars, proto.coeffs(), - proto.offset()); - } -} - -class IntervalVar { - public: - IntervalVar(std::shared_ptr model_proto, int index) - : model_proto_(model_proto), index_(index) { - DCHECK_GE(index, 0); - } - - int index() const { return index_; } - - std::shared_ptr model_proto() const { return model_proto_; } - - ConstraintProto* proto() const { - return model_proto_->mutable_constraints(index_); - } - - std::string ToString() const { - const std::string name = proto()->name(); - if (name.empty()) { - return absl::StrCat("iv", index_); - } else { - return name; - } - } - - std::string DebugString() const { - if (proto()->enforcement_literal().empty()) { - return absl::StrCat( - name(), "(start = ", - ShortExprName(proto()->interval().start(), model_proto()), - ", size = ", ShortExprName(proto()->interval().size(), model_proto()), - ", end = ", ShortExprName(proto()->interval().end(), model_proto()), - ")"); - } else { - return absl::StrCat( - name(), "(start = ", - ShortExprName(proto()->interval().start(), model_proto()), - ", size = ", ShortExprName(proto()->interval().size(), model_proto()), - ", end = ", ShortExprName(proto()->interval().end(), model_proto()), - ", is_present = ", - ShortName(proto()->enforcement_literal(0), model_proto()), ")"); - } - } - - std::string name() const { return proto()->name(); } - - void SetName(const std::string& name) { proto()->set_name(name); } - - std::shared_ptr StartExpr() const { - return RebuildFromLinearExpressionProto(proto()->interval().start(), - model_proto_); - } - std::shared_ptr SizeExpr() const { - return RebuildFromLinearExpressionProto(proto()->interval().size(), - model_proto_); - } - std::shared_ptr EndExpr() const { - return RebuildFromLinearExpressionProto(proto()->interval().end(), - model_proto_); - } - - std::vector> PresenceLiterals() const { - std::vector> literals; - for (const int lit : proto()->enforcement_literal()) { - if (RefIsPositive(lit)) { - literals.push_back(std::make_shared(model_proto_, lit)); - } else { - literals.push_back(std::make_shared( - model_proto_, NegatedRef(lit))); - } - } - return literals; - } - - private: - std::shared_ptr model_proto_; - int index_; -}; - -std::shared_ptr CpBaseModel::NewIntervalVarInternal( - const std::string& name, const py::handle& start, const py::handle& size, - const py::handle& end, py::sequence literals) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - if (!name.empty()) ct->set_name(name); - LinearExprToProto(start, 1, ct->mutable_interval()->mutable_start()); - LinearExprToProto(size, 1, ct->mutable_interval()->mutable_size()); - LinearExprToProto(end, 1, ct->mutable_interval()->mutable_end()); - for (const auto& lit : literals) { - ct->add_enforcement_literal(GetOrMakeBooleanIndex(lit)); - } - const std::string method = literals.empty() - ? "cp_model.new_interval_var" - : "cp_model.new_optional_interval_var"; - if (ct->interval().start().vars_size() > 1) { - ThrowError(PyExc_TypeError, - absl::StrCat(method, ": start must be affine or constant.")); - } - if (ct->interval().size().vars_size() > 1) { - ThrowError(PyExc_TypeError, - absl::StrCat(method, ": size must be affine or constant.")); - } - if (ct->interval().end().vars_size() > 1) { - ThrowError(PyExc_TypeError, - absl::StrCat(method, ": end must be affine or constant.")); - } - return std::make_shared(model_proto_, ct_index); -} - -std::shared_ptr CpBaseModel::AddNoOverlapInternal( - const std::vector>& intervals) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - ct->mutable_no_overlap()->mutable_intervals()->Reserve(intervals.size()); - for (const std::shared_ptr& interval : intervals) { - ct->mutable_no_overlap()->add_intervals(interval->index()); - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddNoOverlap2DInternal( - const std::vector>& x_intervals, - const std::vector>& y_intervals) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - ct->mutable_no_overlap_2d()->mutable_x_intervals()->Reserve( - x_intervals.size()); - for (const std::shared_ptr& x_interval : x_intervals) { - ct->mutable_no_overlap_2d()->add_x_intervals(x_interval->index()); - } - ct->mutable_no_overlap_2d()->mutable_y_intervals()->Reserve( - y_intervals.size()); - for (const std::shared_ptr& y_interval : y_intervals) { - ct->mutable_no_overlap_2d()->add_y_intervals(y_interval->index()); - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddCumulativeInternal( - const std::vector>& intervals, - const py::sequence demands, const py::handle& capacity) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - CumulativeConstraintProto* proto = ct->mutable_cumulative(); - - proto->mutable_intervals()->Reserve(intervals.size()); - for (const std::shared_ptr& interval : intervals) { - proto->add_intervals(interval->index()); - } - - proto->mutable_demands()->Reserve(demands.size()); - for (const auto& expr : demands) { - LinearExprToProto(expr, 1, proto->add_demands()); - } - - LinearExprToProto(capacity, 1, proto->mutable_capacity()); - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddCircuitInternal( - const std::vector>& arcs) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - CircuitConstraintProto* proto = ct->mutable_circuit(); - proto->mutable_tails()->Reserve(arcs.size()); - proto->mutable_heads()->Reserve(arcs.size()); - proto->mutable_literals()->Reserve(arcs.size()); - for (const auto& [tail, head, lit] : arcs) { - proto->add_tails(tail); - proto->add_heads(head); - proto->add_literals(GetOrMakeBooleanIndex(lit)); - } - return std::make_shared(shared_from_this(), ct_index); -} - -std::shared_ptr CpBaseModel::AddRoutesInternal( - const std::vector>& arcs) { - const int ct_index = model_proto_->constraints_size(); - ConstraintProto* ct = model_proto_->add_constraints(); - RoutesConstraintProto* proto = ct->mutable_routes(); - proto->mutable_tails()->Reserve(arcs.size()); - proto->mutable_heads()->Reserve(arcs.size()); - proto->mutable_literals()->Reserve(arcs.size()); - for (const auto& [tail, head, lit] : arcs) { - proto->add_tails(tail); - proto->add_heads(head); - proto->add_literals(GetOrMakeBooleanIndex(lit)); - } - return std::make_shared(shared_from_this(), ct_index); -} - -#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) -bool was_optimized_in_function_call(PyObject* op) { - PyFrameObject* frame = PyEval_GetFrame(); - if (frame == NULL) { - return false; - } - _PyInterpreterFrame* f = frame->f_frame; - _PyStackRef* base = _PyFrame_Stackbase(f); - _PyStackRef* stackpointer = f->stackpointer; - - while (stackpointer > base) { - stackpointer--; - if (op == PyStackRef_AsPyObjectBorrow(*stackpointer)) { - // We want detect if the object is a temporary and borrowed. If so, it - // should be only referenced once in the stack, but it should not be safe. - return !PyStackRef_IsHeapSafe(*stackpointer); - } - } - return false; -} - -bool IsOnwedExclusivelyThroughPyBind11(PyObject* op) { -#if !defined(Py_GIL_DISABLED) - return Py_REFCNT(op) == 3; -#else - // NOTE: the entire ob_ref_shared field must be zero, including flags, to - // ensure that other threads cannot concurrently create new references to - // this object. - return (_Py_IsOwnedByCurrentThread(op) && - _Py_atomic_load_uint32_relaxed(&op->ob_ref_local) == 3 && - _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared) == 0); -#endif -} - -template -bool IsFree(std::shared_ptr expr) { - PyObject* op = py::cast(expr).ptr(); - return IsOnwedExclusivelyThroughPyBind11(op) && - !was_optimized_in_function_call(op); -} -#else -template -bool IsFree(std::shared_ptr expr) { - return Py_REFCNT(py::cast(expr).ptr()) == 4; -} -#endif - PYBIND11_MODULE(cp_model_helper, m) { py::module::import("ortools.util.python.sorted_interval_list"); - // We keep the CamelCase name for the SolutionCallback class to be - // compatible with the pre PEP8 python code. - py::class_(m, "SolutionCallback") - .def(py::init<>()) - .def("OnSolutionCallback", &SolutionCallback::OnSolutionCallback) - .def("BestObjectiveBound", &SolutionCallback::BestObjectiveBound) - .def("DeterministicTime", &SolutionCallback::DeterministicTime) - .def("HasResponse", &SolutionCallback::HasResponse) - .def("NumBinaryPropagations", &SolutionCallback::NumBinaryPropagations) - .def("NumBooleans", &SolutionCallback::NumBooleans) - .def("NumBranches", &SolutionCallback::NumBranches) - .def("NumConflicts", &SolutionCallback::NumConflicts) - .def("NumIntegerPropagations", &SolutionCallback::NumIntegerPropagations) - .def("ObjectiveValue", &SolutionCallback::ObjectiveValue) - .def("Response", &SolutionCallback::SharedResponse) - .def("SolutionBooleanValue", &SolutionCallback::SolutionBooleanValue, - py::arg("index")) - .def("SolutionIntegerValue", &SolutionCallback::SolutionIntegerValue, - py::arg("index")) - .def("StopSearch", &SolutionCallback::StopSearch) - .def("UserTime", &SolutionCallback::UserTime) - .def("WallTime", &SolutionCallback::WallTime) - .def( - "Value", - [](const SolutionCallback& self, std::shared_ptr expr) { - return ResponseHelper::Value(self.SharedResponse(), expr); - }, - "Returns the value of a linear expression after solve.") - .def( - "Value", [](const SolutionCallback&, int64_t value) { return value; }, - "Returns the value of a linear expression after solve.") - .def( - "FloatValue", - [](const SolutionCallback& self, std::shared_ptr expr) { - return ResponseHelper::FloatValue(self.SharedResponse(), expr); - }, - "Returns the value of a floating point linear expression after " - "solve.") - .def( - "FloatValue", - [](const SolutionCallback&, double value) { return value; }, - "Returns the value of a floating point linear expression after " - "solve.") - .def( - "BooleanValue", - [](const SolutionCallback& self, std::shared_ptr lit) { - return ResponseHelper::BooleanValue(self.SharedResponse(), lit); - }, - "Returns the Boolean value of a literal after solve.") - .def( - "BooleanValue", [](const SolutionCallback&, bool lit) { return lit; }, - "Returns the Boolean value of a literal after solve."); - - py::class_(m, "ResponseHelper") - .def_static("boolean_value", &ResponseHelper::BooleanValue, - py::arg("response").none(false), py::arg("lit").none(false)) - .def_static("boolean_value", &ResponseHelper::FixedBooleanValue, - py::arg("response").none(false), py::arg("lit").none(false)) - .def_static("float_value", &ResponseHelper::FloatValue, - py::arg("response").none(false), py::arg("expr").none(false)) - .def_static("float_value", &ResponseHelper::FixedFloatValue, - py::arg("response").none(false), py::arg("value").none(false)) - .def_static("sufficient_assumptions_for_infeasibility", - &ResponseHelper::SufficientAssumptionsForInfeasibility, - py::arg("response").none(false)) - .def_static("value", &ResponseHelper::Value, - py::arg("response").none(false), py::arg("expr").none(false)) - .def_static("value", &ResponseHelper::FixedValue, - py::arg("response").none(false), - py::arg("value").none(false)); - - py::class_(m, "SolveWrapper") - .def(py::init<>()) - .def( - "add_log_callback", - [](ExtSolveWrapper* solve_wrapper, - std::function log_callback) { - std::function safe_log_callback = - [solve_wrapper, log_callback](std::string message) -> void { - ::py::gil_scoped_acquire acquire; - try { - log_callback(message); - } catch (py::error_already_set& e) { - // We assume this code is serialized as the gil is held. - if (!solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_ = e; - } - solve_wrapper->StopSearch(); - } - }; - solve_wrapper->AddLogCallback(safe_log_callback); - }, - py::arg("log_callback").none(false)) - .def("add_solution_callback", &SolveWrapper::AddSolutionCallback, - py::arg("callback")) - .def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback) - .def( - "add_best_bound_callback", - [](ExtSolveWrapper* solve_wrapper, - std::function best_bound_callback) { - std::function safe_best_bound_callback = - [solve_wrapper, best_bound_callback](double bound) -> void { - ::py::gil_scoped_acquire acquire; - try { - best_bound_callback(bound); - } catch (py::error_already_set& e) { - // We assume this code is serialized as the gil is held. - if (!solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_ = e; - } - solve_wrapper->StopSearch(); - } - }; - solve_wrapper->AddBestBoundCallback(safe_best_bound_callback); - }, - py::arg("best_bound_callback").none(false)) - .def( - "set_parameters", - [](ExtSolveWrapper* solve_wrapper, - std::shared_ptr parameters) { - solve_wrapper->SetParameters(*parameters); - }, - py::arg("parameters").none(false)) - .def( - "solve", - [](ExtSolveWrapper* solve_wrapper, - std::shared_ptr model_proto) -> CpSolverResponse { - const auto result = [=]() -> CpSolverResponse { - ::py::gil_scoped_release release; - return solve_wrapper->Solve(*model_proto); - }(); - if (solve_wrapper->local_error_already_set_.has_value()) { - solve_wrapper->local_error_already_set_->restore(); - solve_wrapper->local_error_already_set_.reset(); - throw py::error_already_set(); - } - return result; - }, - py::arg("model_proto").none(false)) - .def("stop_search", &SolveWrapper::StopSearch); - - py::class_(m, "CpSatHelper") - .def_static("model_stats", &CpSatHelper::ModelStats, - py::arg("model_proto")) - .def_static("solver_response_stats", &CpSatHelper::SolverResponseStats, - py::arg("response")) - .def_static("validate_model", &CpSatHelper::ValidateModel, - py::arg("model_proto")) - .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, - py::arg("model_proto"), py::arg("filename")); - - py::class_>( - m, "LinearExpr", DOC(operations_research, sat, python, LinearExpr)) - .def_static("sum", &SumArguments, "Returns the sum(expressions).") - .def_static("weighted_sum", &WeightedSumArguments, py::arg("expressions"), - py::arg("coefficients"), - "Returns the sum of (expressions[i] * coefficients[i])") - .def_static("term", &LinearExpr::TermInt, py::arg("expr").none(false), - py::arg("coeff"), - DOC(operations_research, sat, python, LinearExpr, TermInt)) - .def_static("term", &LinearExpr::TermFloat, py::arg("expr").none(false), - py::arg("coeff"), - DOC(operations_research, sat, python, LinearExpr, TermFloat)) - .def_static("affine", &LinearExpr::AffineInt, py::arg("expr").none(false), - py::arg("coeff"), py::arg("offset"), - DOC(operations_research, sat, python, LinearExpr, AffineInt)) - .def_static( - "affine", &LinearExpr::AffineFloat, py::arg("expr").none(false), - py::arg("coeff"), py::arg("offset"), - DOC(operations_research, sat, python, LinearExpr, AffineFloat)) - .def_static( - "constant", &LinearExpr::ConstantInt, py::arg("value"), - DOC(operations_research, sat, python, LinearExpr, ConstantInt)) - .def_static( - "constant", &LinearExpr::ConstantFloat, py::arg("value"), - DOC(operations_research, sat, python, LinearExpr, ConstantFloat)) - // Pre PEP8 compatibility layer. - .def_static("Sum", &SumArguments) - .def_static("WeightedSum", &WeightedSumArguments, py::arg("expressions"), - py::arg("coefficients")) - .def_static("Term", &LinearExpr::TermInt, py::arg("expr").none(false), - py::arg("coeff"), "Returns expr * coeff.") - .def_static("Term", &LinearExpr::TermFloat, py::arg("expr").none(false), - py::arg("coeff"), "Returns expr * coeff.") - // Methods. - .def("__str__", - [](std::shared_ptr expr) -> std::string { - return expr->ToString(); - }) - .def("__repr__", - [](std::shared_ptr expr) -> std::string { - return expr->DebugString(); - }) - .def( - "is_integer", - [](std::shared_ptr expr) -> bool { - return expr->IsInteger(); - }, - DOC(operations_research, sat, python, LinearExpr, IsInteger)) - // Operators. - .def("__add__", &LinearExpr::Add, py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Add)) - .def("__add__", &LinearExpr::AddInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddInt)) - .def("__add__", &LinearExpr::AddFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddFloat)) - .def("__radd__", &LinearExpr::AddInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddInt)) - .def("__radd__", &LinearExpr::AddFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddFloat)) - .def("__sub__", &LinearExpr::Sub, py::arg("h").none(false), - DOC(operations_research, sat, python, LinearExpr, Sub)) - .def("__sub__", &LinearExpr::SubInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, SubInt)) - .def("__sub__", &LinearExpr::SubFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, SubFloat)) - .def("__rsub__", &LinearExpr::RSub, py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, RSub)) - .def("__rsub__", &LinearExpr::RSubInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, RSubInt)) - .def("__rsub__", &LinearExpr::RSubFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, RSubFloat)) - .def("__mul__", &LinearExpr::MulInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, MulInt)) - .def("__mul__", &LinearExpr::MulFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, MulFloat)) - .def("__rmul__", &LinearExpr::MulInt, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, MulInt)) - .def("__rmul__", &LinearExpr::MulFloat, py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, MulFloat)) - .def("__neg__", &LinearExpr::Neg, - DOC(operations_research, sat, python, LinearExpr, Neg)) - .def( - "__eq__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Eq(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Eq)) - .def( - "__eq__", - [](std::shared_ptr lhs, int64_t rhs) { - if (rhs == std::numeric_limits::max() || - rhs == std::numeric_limits::min()) { - ThrowError(PyExc_ValueError, - "== INT_MIN or INT_MAX is not supported"); - } - return CheckBoundedLinearExpression(lhs->EqCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, EqCst)) - .def( - "__ne__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Ne(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Ne)) - .def( - "__ne__", - [](std::shared_ptr lhs, int64_t rhs) { - return CheckBoundedLinearExpression(lhs->NeCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, NeCst)) - .def( - "__le__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Le(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Le)) - .def( - "__le__", - [](std::shared_ptr lhs, int64_t rhs) { - if (rhs == std::numeric_limits::min()) { - ThrowError(PyExc_ArithmeticError, "<= INT_MIN is not supported"); - } - return CheckBoundedLinearExpression(lhs->LeCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, LeCst)) - .def( - "__lt__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Lt(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Lt)) - .def( - "__lt__", - [](std::shared_ptr lhs, int64_t rhs) { - if (rhs == std::numeric_limits::min()) { - ThrowError(PyExc_ArithmeticError, "< INT_MIN is not supported"); - } - return CheckBoundedLinearExpression(lhs->LtCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, LtCst)) - .def( - "__ge__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Ge(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Ge)) - .def( - "__ge__", - [](std::shared_ptr lhs, int64_t rhs) { - if (rhs == std::numeric_limits::max()) { - ThrowError(PyExc_ArithmeticError, ">= INT_MAX is not supported"); - } - return CheckBoundedLinearExpression(lhs->GeCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, GeCst)) - .def( - "__gt__", - [](std::shared_ptr lhs, std::shared_ptr rhs) { - RaiseIfNone(rhs); - return CheckBoundedLinearExpression(lhs->Gt(rhs), lhs, rhs); - }, - DOC(operations_research, sat, python, LinearExpr, Gt)) - .def( - "__gt__", - [](std::shared_ptr lhs, int64_t rhs) { - if (rhs == std::numeric_limits::max()) { - ThrowError(PyExc_ArithmeticError, "> INT_MAX is not supported"); - } - return CheckBoundedLinearExpression(lhs->GtCst(rhs), lhs); - }, - DOC(operations_research, sat, python, LinearExpr, GtCst)) - // Disable other operators as they are not supported. - .def("__div__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling / on a linear expression is not supported, " - "please use CpModel.add_division_equality"); - }) - .def("__truediv__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling // on a linear expression is not supported, " - "please use CpModel.add_division_equality"); - }) - .def("__mod__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling %% on a linear expression is not supported, " - "please use CpModel.add_modulo_equality"); - }) - .def("__pow__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling ** on a linear expression is not supported, " - "please use CpModel.add_multiplication_equality"); - }) - .def("__lshift__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling left shift on a linear expression is not " - "supported"); - }) - .def("__rshift__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling right shift on a linear expression is " - "not supported"); - }) - .def("__and__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling and on a linear expression is not supported"); - }) - .def("__or__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling or on a linear expression is not supported"); - }) - .def("__xor__", - [](std::shared_ptr /*self*/, py::handle /*other*/) { - ThrowError(PyExc_NotImplementedError, - "calling xor on a linear expression is not supported"); - }) - .def("__abs__", - [](std::shared_ptr /*self*/) { - ThrowError( - PyExc_NotImplementedError, - "calling abs() on a linear expression is not supported, " - "please use CpModel.add_abs_equality"); - }) - .def("__bool__", [](std::shared_ptr /*self*/) { - ThrowError(PyExc_NotImplementedError, - "Evaluating a LinearExpr instance as a Boolean is " - "not supported."); - }); - - // Expose Internal classes, mostly for testing. - py::class_, LinearExpr>( - m, "FlatFloatExpr", DOC(operations_research, sat, python, FlatFloatExpr)) - .def(py::init>()) - .def_property_readonly("vars", &FlatFloatExpr::vars) - .def_property_readonly("coeffs", &FlatFloatExpr::coeffs) - .def_property_readonly("offset", &FlatFloatExpr::offset); - - py::class_, LinearExpr>( - m, "FlatIntExpr", DOC(operations_research, sat, python, FlatIntExpr)) - .def(py::init([](std::shared_ptr expr) { - FlatIntExpr* result = new FlatIntExpr(expr); - if (!result->ok()) { - ThrowError(PyExc_TypeError, - absl::StrCat("Tried to build a FlatIntExpr from a linear " - "expression with " - "floating point coefficients or constants: ", - expr->DebugString())); - } - return result; - })) - .def_property_readonly("vars", &FlatIntExpr::vars) - .def_property_readonly("coeffs", &FlatIntExpr::coeffs) - .def_property_readonly("offset", &FlatIntExpr::offset) - .def_property_readonly("ok", &FlatIntExpr::ok); - - py::class_, LinearExpr>( - m, "SumArray", DOC(operations_research, sat, python, SumArray)) - .def( - py::init>, int64_t, double>()) - .def( - "__add__", - [](std::shared_ptr expr, - std::shared_ptr other) -> std::shared_ptr { - return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Add)) - .def( - "__add__", - [](std::shared_ptr expr, - int64_t cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); - }, - DOC(operations_research, sat, python, LinearExpr, AddInt)) - .def( - "__add__", - [](std::shared_ptr expr, - double cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, AddFloat)) - .def( - "__radd__", - [](std::shared_ptr expr, - std::shared_ptr other) -> std::shared_ptr { - return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Add)) - .def( - "__radd__", - [](std::shared_ptr expr, - int64_t cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); - }, - py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddInt)) - .def( - "__radd__", - [](std::shared_ptr expr, - double cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddFloatInPlace(cst) - : expr->AddFloat(cst); - }, - py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, AddFloat)) - .def( - "__iadd__", - [](std::shared_ptr expr, - std::shared_ptr other) -> std::shared_ptr { - return expr->AddInPlace(other); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Add)) - .def( - "__iadd__", - [](std::shared_ptr expr, - int64_t cst) -> std::shared_ptr { - return expr->AddIntInPlace(cst); - }, - DOC(operations_research, sat, python, LinearExpr, AddInt)) - .def( - "__iadd__", - [](std::shared_ptr expr, - double cst) -> std::shared_ptr { - return expr->AddFloatInPlace(cst); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, AddFloat)) - .def( - "__sub__", - [](std::shared_ptr expr, - std::shared_ptr other) -> std::shared_ptr { - return IsFree(expr) ? expr->AddInPlace(other->Neg()) - : expr->Sub(other); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Sub)) - .def( - "__sub__", - [](std::shared_ptr expr, - int64_t cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddIntInPlace(-cst) : expr->SubInt(cst); - }, - py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, SubInt)) - .def( - "__sub__", - [](std::shared_ptr expr, - double cst) -> std::shared_ptr { - return IsFree(expr) ? expr->AddFloatInPlace(-cst) - : expr->SubFloat(cst); - }, - py::arg("cst"), - DOC(operations_research, sat, python, LinearExpr, SubFloat)) - .def( - "__isub__", - [](std::shared_ptr expr, - std::shared_ptr other) -> std::shared_ptr { - return expr->AddInPlace(other->Neg()); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, Sub)) - .def( - "__isub__", - [](std::shared_ptr expr, - int64_t cst) -> std::shared_ptr { - return expr->AddIntInPlace(-cst); - }, - DOC(operations_research, sat, python, LinearExpr, SubInt)) - .def( - "__isub__", - [](std::shared_ptr expr, - double cst) -> std::shared_ptr { - return expr->AddFloatInPlace(-cst); - }, - py::arg("other").none(false), - DOC(operations_research, sat, python, LinearExpr, SubFloat)) - .def_property_readonly("num_exprs", &SumArray::num_exprs) - .def_property_readonly("int_offset", &SumArray::int_offset) - .def_property_readonly("double_offset", &SumArray::double_offset); - - py::class_, LinearExpr>( - m, "FloatAffine", DOC(operations_research, sat, python, FloatAffine)) - .def(py::init, double, double>()) - .def_property_readonly("expression", &FloatAffine::expression) - .def_property_readonly("coefficient", &FloatAffine::coefficient) - .def_property_readonly("offset", &FloatAffine::offset); - - py::class_, LinearExpr>( - m, "IntAffine", DOC(operations_research, sat, python, IntAffine)) - .def(py::init, int64_t, int64_t>()) - .def_property_readonly("expression", &IntAffine::expression, - "Returns the linear expression.") - .def_property_readonly("coefficient", &IntAffine::coefficient, - "Returns the coefficient.") - .def_property_readonly("offset", &IntAffine::offset, - "Returns the offset."); - - py::class_, LinearExpr>( - m, "Literal", DOC(operations_research, sat, python, Literal)) - .def_property_readonly( - "index", &Literal::index, - DOC(operations_research, sat, python, Literal, index)) - .def("negated", &Literal::negated, - DOC(operations_research, sat, python, Literal, negated)) - .def("__invert__", &Literal::negated, - DOC(operations_research, sat, python, Literal, negated)) - .def("__bool__", - [](std::shared_ptr /*self*/) { - ThrowError(PyExc_NotImplementedError, - "Evaluating a Literal as a Boolean value is " - "not supported."); - }) - .def("__hash__", &Literal::Hash) - // Pre PEP8 compatibility layer. - .def("Not", &Literal::negated) - .def("Index", &Literal::index); - - // IntVar and NotBooleanVariable both hold a shared_ptr to the model_proto. - py::class_, Literal>( - m, "IntVar", DOC(operations_research, sat, python, IntVar)) - .def(py::init, int>()) - .def(py::init>()) // new variable. - .def_property_readonly( - "proto", &IntVar::proto, py::return_value_policy::reference_internal, - "Returns the IntegerVariableProto of this variable.") - .def_property_readonly("model_proto", &IntVar::model_proto, - "Returns the CP model protobuf") - .def_property_readonly( - "index", &IntVar::index, py::return_value_policy::reference, - DOC(operations_research, sat, python, IntVar, index)) - .def_property_readonly( - "is_boolean", &IntVar::is_boolean, - DOC(operations_research, sat, python, IntVar, is_boolean)) - .def_property("name", &IntVar::name, &IntVar::SetName, - "The name of the variable.") - .def( - "with_name", - [](std::shared_ptr self, const std::string& name) { - self->SetName(name); - return self; - }, - py::arg("name"), - "Sets the name of the variable and returns the variable.") - .def_property("domain", &IntVar::domain, &IntVar::SetDomain, - "The domain of the variable.") - .def( - "with_domain", - [](std::shared_ptr self, const Domain& domain) { - self->SetDomain(domain); - return self; - }, - py::arg("domain"), - "Sets the domain of the variable and returns the variable.") - .def("__str__", &IntVar::ToString) - .def("__repr__", &IntVar::DebugString) - .def( - "negated", - [](std::shared_ptr self) { - if (!self->is_boolean()) { - ThrowError(PyExc_TypeError, - "negated() is only supported for Boolean variables."); - } - return self->negated(); - }, - DOC(operations_research, sat, python, IntVar, negated)) - .def( - "__invert__", - [](std::shared_ptr self) { - if (!self->is_boolean()) { - ThrowError(PyExc_TypeError, - "negated() is only supported for Boolean variables."); - } - return self->negated(); - }, - DOC(operations_research, sat, python, IntVar, negated)) - .def("__copy__", - [](const std::shared_ptr& self) { - return std::make_shared(self->model_proto(), - self->index()); - }) - .def(py::pickle( - [](std::shared_ptr p) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(p->model_proto(), p->index()); - }, - [](py::tuple t) { // __setstate__ - if (t.size() != 2) throw std::runtime_error("Invalid state!"); - - return std::make_shared( - t[0].cast>(), t[1].cast()); - })) - // Pre PEP8 compatibility layer. - .def("Name", &IntVar::name) - .def("Proto", &IntVar::proto, py::return_value_policy::reference, - py::keep_alive<1, 0>(), - "Returns the IntegerVariableProto of this variable.") - .def("Not", - [](std::shared_ptr self) { - if (!self->is_boolean()) { - ThrowError(PyExc_TypeError, - "negated() is only supported for Boolean variables."); - } - return self->negated(); - }) - .def("Index", &IntVar::index); - - py::class_, Literal>( - m, "NotBooleanVariable", - DOC(operations_research, sat, python, NotBooleanVariable)) - .def_property_readonly( - "index", - [](std::shared_ptr not_var) -> int { - return not_var->index(); - }, - DOC(operations_research, sat, python, NotBooleanVariable, index)) - .def("__str__", - [](std::shared_ptr not_var) -> std::string { - return not_var->ToString(); - }) - .def("__repr__", - [](std::shared_ptr not_var) -> std::string { - return not_var->DebugString(); - }) - .def( - "negated", - [](std::shared_ptr not_var) - -> std::shared_ptr { return not_var->negated(); }, - DOC(operations_research, sat, python, NotBooleanVariable, negated)) - .def( - "__invert__", - [](std::shared_ptr not_var) - -> std::shared_ptr { return not_var->negated(); }, - DOC(operations_research, sat, python, NotBooleanVariable, negated)) - // Pre PEP8 compatibility layer. - .def( - "Not", - [](std::shared_ptr not_var) - -> std::shared_ptr { return not_var->negated(); }, - DOC(operations_research, sat, python, NotBooleanVariable, negated)); - - py::class_>( - m, "BoundedLinearExpression", - DOC(operations_research, sat, python, BoundedLinearExpression)) - .def(py::init, Domain>()) - .def(py::init, - const std::shared_ptr, Domain>()) - .def_property_readonly("bounds", &BoundedLinearExpression::bounds) - .def_property_readonly("vars", &BoundedLinearExpression::vars) - .def_property_readonly("coeffs", &BoundedLinearExpression::coeffs) - .def_property_readonly("offset", &BoundedLinearExpression::offset) - .def_property_readonly("ok", &BoundedLinearExpression::ok) - .def("__str__", &BoundedLinearExpression::ToString) - .def("__repr__", &BoundedLinearExpression::DebugString) - .def("__bool__", [](const BoundedLinearExpression& self) { - bool result; - if (self.CastToBool(&result)) return result; - ThrowError(PyExc_NotImplementedError, - absl::StrCat("Evaluating a BoundedLinearExpression '", - self.ToString(), - "'instance as a Boolean is " - "not supported.")); - return false; - }); - - py::enum_(m, "BoolArgumentConstraint") - .value("at_most_one", BoolArgumentConstraint::kAtMostOne) - .value("bool_and", BoolArgumentConstraint::kBoolAnd) - .value("bool_or", BoolArgumentConstraint::kBoolOr) - .value("bool_xor", BoolArgumentConstraint::kBoolXor) - .value("exactly_one", BoolArgumentConstraint::kExactlyOne) - .export_values(); - - py::enum_(m, "LinearArgumentConstraint") - .value("div", LinearArgumentConstraint::kDiv) - .value("max", LinearArgumentConstraint::kMax) - .value("min", LinearArgumentConstraint::kMin) - .value("mod", LinearArgumentConstraint::kMod) - .value("prod", LinearArgumentConstraint::kProd) - .export_values(); - - py::class_>( - m, "CpBaseModel", "Base class for the CP model.") - .def(py::init>()) - .def_property_readonly("model_proto", &CpBaseModel::model_proto, - "Returns the CP model protobuf") - .def("get_or_make_index_from_constant", - &CpBaseModel::GetOrMakeIndexFromConstant, py::arg("value"), - "Returns the index of the given constant value.") - .def("get_or_make_boolean_index", &CpBaseModel::GetOrMakeBooleanIndex, - py::arg("value"), "Returns the index of the given boolean value.") - .def("get_or_make_variable_index", &CpBaseModel::GetOrMakeVariableIndex, - py::arg("arg"), - "Returns the index of the given variable or constant variable.") - .def("is_boolean_value", &CpBaseModel::IsBooleanValue, py::arg("value")) - .def("rebuild_constant_map", &CpBaseModel::RebuildConstantMap) - .def("_add_all_different", &CpBaseModel::AddAllDifferentInternal) - .def("_add_automaton", &CpBaseModel::AddAutomatonInternal, - py::arg("transition_expressions"), py::arg("starting_state"), - py::arg("final_states"), py::arg("transition_triples")) - .def("_add_bool_argument_constraint", - &CpBaseModel::AddBoolArgumentConstraintInternal, py::arg("name")) - .def("_add_bounded_linear_expression", - &CpBaseModel::AddBoundedLinearExpressionInternal, py::arg("ble")) - .def("_add_element", &CpBaseModel::AddElementInternal, - py::arg("index").none(false), py::arg("expressions"), - py::arg("target").none(false)) - .def("_add_linear_argument_constraint", - &CpBaseModel::AddLinearArgumentConstraintInternal, - py::arg("name").none(false), py::arg("target").none(false)) - .def("_add_inverse", &CpBaseModel::AddInverseInternal, py::arg("direct"), - py::arg("inverse")) - .def("_add_reservoir", &CpBaseModel::AddReservoirInternal, - py::arg("times"), py::arg("level_changes"), py::arg("actives"), - py::arg("min_level"), py::arg("max_level")) - .def("_add_table", &CpBaseModel::AddTableInternal, py::arg("expressions"), - py::arg("values"), py::arg("negated")) - // Scheduling support. - .def("_new_interval_var", &CpBaseModel::NewIntervalVarInternal, - py::arg("name"), py::arg("start"), py::arg("size"), py::arg("end"), - py::arg("Literals")) - .def("_add_no_overlap", &CpBaseModel::AddNoOverlapInternal, - py::arg("intervals")) - .def("_add_no_overlap_2d", &CpBaseModel::AddNoOverlap2DInternal, - py::arg("x_intervals"), py::arg("y_intervals")) - .def("_add_cumulative", &CpBaseModel::AddCumulativeInternal, - py::arg("intervals"), py::arg("demands"), py::arg("capacity")) - // Routing support. - .def("_add_circuit", &CpBaseModel::AddCircuitInternal, py::arg("arcs")) - .def("_add_routes", &CpBaseModel::AddRoutesInternal, py::arg("arcs")); - - static const char* kConstraintDoc = R"doc( - Base class for constraints. - - Constraints are built by the CpModel through the add methods. - Once created by the CpModel class, they are automatically added to the model. - The purpose of this class is to allow specification of enforcement literals - for this constraint. - - b = model.new_bool_var('b') - x = model.new_int_var(0, 10, 'x') - y = model.new_int_var(0, 10, 'y') - - model.add(x + 2 * y == 5).only_enforce_if(b.negated()) - )doc"; - - static const char* kConstraintOnlyEnforceIfDoc = R"doc( - Adds one or more enforcement literals to the constraint. - - This method adds one or more literals (that is, a boolean variable or its - negation) as enforcement literals. The conjunction of all these literals - determines whether the constraint is active or not. It acts as an - implication, so if the conjunction is true, it implies that the constraint - must be enforced. If it is false, then the constraint is ignored. - - BoolOr, BoolAnd, and linear constraints all support enforcement literals. - - Args: - *literals: One or more Boolean literals. - - Returns: - self.)doc"; - - py::class_>(m, "Constraint", - kConstraintDoc) - .def(py::init, int>()) - .def_property_readonly( - "index", &Constraint::index, - "Returns the index of the constraint in the model protobuf.") - .def_property_readonly("model_proto", &Constraint::model_proto, - "Returns the model protobuf.") - .def_property_readonly("proto", &Constraint::proto, - py::return_value_policy::reference_internal, - "Returns the ConstraintProto of this constraint.") - .def_property("name", &Constraint::name, &Constraint::SetName, - "The name of the constraint.") - .def( - "with_name", - [](Constraint* self, const std::string& name) { - if (name.empty()) { - self->ClearName(); - } else { - self->SetName(name); - } - return self; - }, - "Sets the name of the constraint and returns the constraints") - .def( - "only_enforce_if", - [](std::shared_ptr self, - std::shared_ptr literal) { - self->proto()->add_enforcement_literal(literal->index()); - return self; - }, - py::arg("literal"), kConstraintOnlyEnforceIfDoc) - .def( - "only_enforce_if", - [](std::shared_ptr self, bool literal) { - self->proto()->add_enforcement_literal( - self->model()->GetOrMakeIndexFromConstant(literal)); - return self; - }, - py::arg("literal"), kConstraintOnlyEnforceIfDoc) - .def( - "only_enforce_if", - [](std::shared_ptr self, - const std::vector>& literals) { - for (const std::shared_ptr& literal : literals) { - self->proto()->add_enforcement_literal(literal->index()); - } - }, - py::arg("literals"), kConstraintOnlyEnforceIfDoc) - .def( - "only_enforce_if", - [](std::shared_ptr self, py::args literals) { - if (literals.size() == 1 && - py::isinstance(literals[0])) { - py::sequence seq = literals[0].cast(); - for (const auto& literal : seq) { - self->proto()->add_enforcement_literal( - self->model()->GetOrMakeBooleanIndex(literal)); - } - } else { - for (const auto& literal : literals) { - self->proto()->add_enforcement_literal( - self->model()->GetOrMakeBooleanIndex(literal)); - } - } - }, - kConstraintOnlyEnforceIfDoc) - // Pre PEP8 compatibility. - .def("Name", &Constraint::name) - .def("Index", &Constraint::index) - .def("Proto", &Constraint::proto) - .def("WithName", - [](Constraint* self, const std::string& name) { - if (name.empty()) { - self->ClearName(); - } else { - self->SetName(name); - } - return self; - }) - .def("OnlyEnforceIf", [](std::shared_ptr self, - py::args literals) { - if (literals.size() == 1 && py::isinstance(literals[0])) { - py::sequence seq = literals[0].cast(); - for (const auto& literal : seq) { - self->proto()->add_enforcement_literal( - self->model()->GetOrMakeBooleanIndex(literal)); - } - } else { - for (const auto& literal : literals) { - self->proto()->add_enforcement_literal( - self->model()->GetOrMakeBooleanIndex(literal)); - } - } - }); - - static const char* kIntervalVarDoc = R"doc( -Represents an Interval variable. - -An interval variable is both a constraint and a variable. It is defined by -three integer variables: start, size, and end. - -It is a constraint because, internally, it enforces that start + size == end. - -It is also a variable as it can appear in specific scheduling constraints: -NoOverlap, NoOverlap2D, Cumulative. - -Optionally, an enforcement literal can be added to this constraint, in which -case these scheduling constraints will ignore interval variables with -enforcement literals assigned to false. Conversely, these constraints will -also set these enforcement literals to false if they cannot fit these -intervals into the schedule. - -Raises: - ValueError: if start, size, end are not defined, or have the wrong type. -)doc"; - - py::class_>(m, "IntervalVar", - kIntervalVarDoc) - .def(py::init, int>()) - .def_property_readonly("index", &IntervalVar::index, - "Returns the index of the interval variable.") - .def_property_readonly("model_proto", &IntervalVar::model_proto, - "Returns the model protobuf.") - .def_property_readonly("proto", &IntervalVar::proto, - py::return_value_policy::reference_internal, - "Returns the interval constraint protobuf.") - .def_property("name", &IntervalVar::name, &IntervalVar::SetName, - "The name of the interval variable.") - .def( - "start_expr", - [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.start().vars().empty()) { - return py::cast(proto.start().offset()); - } else { - return py::cast(self->StartExpr()); - } - }, - "Returns the start expression of the interval variable.") - .def( - "size_expr", - [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.size().vars().empty()) { - return py::cast(proto.size().offset()); - } else { - return py::cast(self->SizeExpr()); - } - }, - "Returns the size expression of the interval variable.") - .def( - "end_expr", - [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.end().vars().empty()) { - return py::cast(proto.end().offset()); - } else { - return py::cast(self->EndExpr()); - } - }, - "Returns the end expression of the interval variable.") - .def("presence_literals", &IntervalVar::PresenceLiterals, - "Returns the list of enforcement literals of the interval variable.") - .def("__str__", &IntervalVar::ToString) - .def("__repr__", &IntervalVar::DebugString) - .def(py::pickle( - [](std::shared_ptr p) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(p->model_proto(), p->index()); - }, - [](py::tuple t) { // __setstate__ - if (t.size() != 2) throw std::runtime_error("Invalid state!"); - - return std::make_shared( - t[0].cast>(), t[1].cast()); - })) - // Pre PEP8 compatibility layer. - .def("Proto", &IntervalVar::proto) - .def("Index", &IntervalVar::index) - .def("Name", &IntervalVar::name) - .def("StartExpr", - [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.start().vars().empty()) { - return py::cast(proto.start().offset()); - } else { - return py::cast(self->StartExpr()); - } - }) - .def("SizeExpr", - [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.size().vars().empty()) { - return py::cast(proto.size().offset()); - } else { - return py::cast(self->SizeExpr()); - } - }) - .def("EndExpr", [](std::shared_ptr self) -> py::object { - const IntervalConstraintProto& proto = self->proto()->interval(); - if (proto.end().vars().empty()) { - return py::cast(proto.end().offset()); - } else { - return py::cast(self->EndExpr()); - } - }); - - m.def( - "rebuild_from_linear_expression_proto", - [](const LinearExpressionProto& proto, - std::shared_ptr model_proto) -> py::object { - if (proto.vars().empty()) { - return py::cast(proto.offset()); - } else { - return py::cast(RebuildFromLinearExpressionProto(proto, model_proto)); - } - }, - py::arg("proto"), py::arg("model_proto")); + DefinePybindWrapperForLinearExpr(m); + DefinePybindWrapperForConstraints(m); + DefinePybindWrapperForSolver(m); #define IMPORT_PROTO_WRAPPER_CODE #include "ortools/sat/python/proto_builder_pybind11.h" diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index f11c39941b6..8e22c6435fa 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -1419,6 +1419,14 @@ def test_rebuild_from_linear_expression_proto(self) -> None: self.assertEqual(canonical_expr2.coeffs[1], 2) self.assertEqual(canonical_expr2.offset, 2) + def test_prettyprint_model_proto(self) -> None: + model = cp_model.CpModel() + model.new_int_var(0, 4, "x") + self.assertEqual( + cmh.prettyprint_model_proto(model.proto), + 'variables { name: "x" domain: [0, 4] }\n', + ) + def test_absent_interval(self) -> None: model = cp_model.CpModel() i = model.new_optional_interval_var(1, 0, 1, False, "") diff --git a/ortools/sat/python/linear_expr.cc b/ortools/sat/python/linear_expr.cc index 61f899437e2..eaadc6a7e9d 100644 --- a/ortools/sat/python/linear_expr.cc +++ b/ortools/sat/python/linear_expr.cc @@ -27,6 +27,7 @@ #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/util/fp_roundtrip_conv.h" @@ -698,46 +699,46 @@ std::shared_ptr LinearExpr::NeCst(int64_t rhs) { std::shared_ptr LinearExpr::Le( std::shared_ptr rhs) { - return std::make_shared( - shared_from_this(), rhs, Domain(std::numeric_limits::min(), 0)); + return std::make_shared(shared_from_this(), rhs, + Domain(kint64min, 0)); } std::shared_ptr LinearExpr::LeCst(int64_t rhs) { - return std::make_shared( - shared_from_this(), Domain(std::numeric_limits::min(), rhs)); + return std::make_shared(shared_from_this(), + Domain(kint64min, rhs)); } std::shared_ptr LinearExpr::Lt( std::shared_ptr rhs) { - return std::make_shared( - shared_from_this(), rhs, Domain(std::numeric_limits::min(), -1)); + return std::make_shared(shared_from_this(), rhs, + Domain(kint64min, -1)); } std::shared_ptr LinearExpr::LtCst(int64_t rhs) { - return std::make_shared( - shared_from_this(), Domain(std::numeric_limits::min(), rhs - 1)); + return std::make_shared(shared_from_this(), + Domain(kint64min, rhs - 1)); } std::shared_ptr LinearExpr::Ge( std::shared_ptr rhs) { - return std::make_shared( - shared_from_this(), rhs, Domain(0, std::numeric_limits::max())); + return std::make_shared(shared_from_this(), rhs, + Domain(0, kint64max)); } std::shared_ptr LinearExpr::GeCst(int64_t rhs) { - return std::make_shared( - shared_from_this(), Domain(rhs, std::numeric_limits::max())); + return std::make_shared(shared_from_this(), + Domain(rhs, kint64max)); } std::shared_ptr LinearExpr::Gt( std::shared_ptr rhs) { - return std::make_shared( - shared_from_this(), rhs, Domain(1, std::numeric_limits::max())); + return std::make_shared(shared_from_this(), rhs, + Domain(1, kint64max)); } std::shared_ptr LinearExpr::GtCst(int64_t rhs) { - return std::make_shared( - shared_from_this(), Domain(rhs + 1, std::numeric_limits::max())); + return std::make_shared(shared_from_this(), + Domain(rhs + 1, kint64max)); } void IntExprVisitor::AddToProcess(std::shared_ptr expr, @@ -1015,13 +1016,13 @@ std::string BoundedLinearExpression::ToString() const { if (bounds_.IsFixed()) { absl::StrAppend(&s, " == ", bounds_.Min()); } else if (bounds_.NumIntervals() == 1) { - if (bounds_.Min() == std::numeric_limits::min()) { - if (bounds_.Max() == std::numeric_limits::max()) { + if (bounds_.Min() == kint64min) { + if (bounds_.Max() == kint64max) { return absl::StrCat("True (unbounded expr ", s, ")"); } else { absl::StrAppend(&s, " <= ", bounds_.Max()); } - } else if (bounds_.Max() == std::numeric_limits::max()) { + } else if (bounds_.Max() == kint64max) { absl::StrAppend(&s, " >= ", bounds_.Min()); } else { return absl::StrCat(bounds_.Min(), " <= ", s, " <= ", bounds_.Max()); diff --git a/ortools/sat/python/pybind_constraint.cc b/ortools/sat/python/pybind_constraint.cc new file mode 100644 index 00000000000..fb59a2b9c7c --- /dev/null +++ b/ortools/sat/python/pybind_constraint.cc @@ -0,0 +1,1115 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/pybind_constraint.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "ortools/base/types.h" +#include "ortools/port/proto_utils.h" +#include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/python/linear_expr.h" +#include "ortools/util/saturated_arithmetic.h" +#include "ortools/util/sorted_interval_list.h" +#include "pybind11/attr.h" +#include "pybind11/cast.h" +#include "pybind11/functional.h" +#include "pybind11/gil.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace operations_research::sat::python { + +namespace { + +namespace py = pybind11; + +void ThrowError(PyObject* py_exception, const std::string& message) { + PyErr_SetString(py_exception, message.c_str()); + throw py::error_already_set(); +} + +class Constraint; +class IntervalVar; + +enum class BoolArgumentConstraint { + kAtMostOne, + kBoolAnd, + kBoolOr, + kBoolXor, + kExactlyOne, +}; + +enum class LinearArgumentConstraint { + kDiv, + kMax, + kMin, + kMod, + kProd, +}; + +class CpBaseModel : public std::enable_shared_from_this { + public: + explicit CpBaseModel(std::shared_ptr model_proto) + : model_proto_(model_proto == nullptr ? std::make_shared() + : model_proto), + numpy_bool_type_(py::dtype::of().attr("type").cast()) { + if (model_proto != nullptr) RebuildConstantMap(); + } + + std::shared_ptr model_proto() const { return model_proto_; } + + int GetOrMakeIndexFromConstant(int64_t value) { + auto it = cache_.find(value); + if (it != cache_.end()) return it->second; + const int index = model_proto_->variables_size(); + IntegerVariableProto* const_var = model_proto_->add_variables(); + const_var->add_domain(value); + const_var->add_domain(value); + cache_[value] = index; + return index; + } + + void RebuildConstantMap() { + cache_.clear(); + for (int i = 0; i < model_proto_->variables_size(); ++i) { + const IntegerVariableProto& var = model_proto_->variables(i); + if (var.domain_size() == 2 && var.domain(0) == var.domain(1) && + var.name().empty()) { // Constants do not have names. + cache_[var.domain(0)] = i; + } + } + } + + int GetOrMakeBooleanIndex(py::handle literal) { + if (py::isinstance(literal)) { + std::shared_ptr var = literal.cast>(); + AssertVariableIsBoolean(var); + return var->index(); + } else if (py::isinstance(literal)) { + std::shared_ptr not_var = + literal.cast>(); + AssertVariableIsBoolean(not_var); + return not_var->index(); + } else if (IsBooleanValue(literal)) { + const bool value = literal.cast(); + if (value) { + return GetOrMakeIndexFromConstant(1); + } else { + return GetOrMakeIndexFromConstant(0); + } + } else if (py::isinstance(literal)) { + const int64_t value = literal.cast(); + if (value == 1 || value == -1) { // -1 = ~False. + return GetOrMakeIndexFromConstant(1); + } + if (value == 0 || value == -2) { // -2 = ~True. + return GetOrMakeIndexFromConstant(0); + } + ThrowError(PyExc_TypeError, absl::StrCat("Invalid literal: ", value)); + } else { + py::type objtype = py::type::of(literal); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: '", + absl::CEscape(type_name), "'")); + } + return 0; // Unreachable. + } + + int GetOrMakeVariableIndex(py::handle arg) { + if (py::isinstance(arg)) { + std::shared_ptr var = arg.cast>(); + return var->index(); + } else if (py::isinstance(arg)) { + return GetOrMakeIndexFromConstant(arg.cast()); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && + getattr(arg, "is_integer")().cast()) { + return GetOrMakeIndexFromConstant(arg.cast()); + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("GetOrMakeVariableIndex() only accept integer " + "variables or constants as argument: '", + absl::CEscape(type_name), "'")); + } + return 0; // Unreachable. + } + + void AssertVariableIsBoolean(std::shared_ptr literal) { + if (PositiveRef(literal->index()) >= model_proto_->variables_size()) { + ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: ", + literal->ToString())); + } + IntegerVariableProto* var = + model_proto_->mutable_variables(PositiveRef(literal->index())); + if (var->domain_size() != 2 || var->domain(0) < 0 || var->domain(1) > 1) { + ThrowError(PyExc_TypeError, absl::StrCat("Invalid boolean literal: ", + literal->ToString())); + } + } + + bool IsBooleanValue(py::handle value) { + return py::isinstance(value) || + py::isinstance(value, numpy_bool_type_); + } + + std::shared_ptr AddAllDifferentInternal(py::args exprs); + + std::shared_ptr AddAutomatonInternal( + py::sequence transition_expressions, int64_t starting_state, + const std::vector& final_states, + const std::vector>& transition_triples); + + std::shared_ptr AddBoolArgumentConstraintInternal( + BoolArgumentConstraint type, py::args literals); + + std::shared_ptr AddBoundedLinearExpressionInternal( + BoundedLinearExpression* ble); + + std::shared_ptr AddElementInternal(const py::handle& index, + py::sequence exprs, + const py::handle& target); + + std::shared_ptr AddInverseInternal(py::sequence direct, + py::sequence inverse); + + std::shared_ptr AddLinearArgumentConstraintInternal( + LinearArgumentConstraint type, const py::handle& target, py::args exprs); + + std::shared_ptr AddReservoirInternal(py::sequence times, + py::sequence level_changes, + py::sequence actives, + int64_t min_level, + int64_t max_level); + + std::shared_ptr AddTableInternal( + py::sequence exprs, const std::vector>& tuples, + bool negated); + + std::shared_ptr NewIntervalVarInternal(const std::string& name, + const py::handle& start, + const py::handle& size, + const py::handle& end, + py::sequence literals); + + std::shared_ptr AddNoOverlapInternal( + const std::vector>& intervals); + + std::shared_ptr AddNoOverlap2DInternal( + const std::vector>& x_intervals, + const std::vector>& y_intervals); + + std::shared_ptr AddCumulativeInternal( + const std::vector>& intervals, + py::sequence demands, const py::handle& capacity); + + std::shared_ptr AddCircuitInternal( + const std::vector>& arcs); + + std::shared_ptr AddRoutesInternal( + const std::vector>& arcs); + + private: + std::shared_ptr model_proto_; + absl::flat_hash_map cache_; + py::type numpy_bool_type_; +}; + +void LinearExprToProto(const py::handle& arg, int64_t multiplier, + LinearExpressionProto* proto) { + proto->Clear(); + if (py::isinstance(arg)) { + std::shared_ptr expr = arg.cast>(); + IntExprVisitor visitor; + visitor.AddToProcess(expr, multiplier); + std::vector> vars; + std::vector coeffs; + int64_t offset = 0; + if (!visitor.Process(&vars, &coeffs, &offset)) { + ThrowError(PyExc_ValueError, + absl::StrCat("Failed to convert integer linear expression: ", + expr->DebugString())); + } + proto->mutable_vars()->Reserve(vars.size()); + for (const auto& var : vars) { + proto->add_vars(var->index()); + } + proto->mutable_coeffs()->Reserve(coeffs.size()); + for (const int64_t coeff : coeffs) { + proto->add_coeffs(coeff); + } + proto->set_offset(offset); + } else if (py::isinstance(arg)) { + int64_t value = arg.cast(); + proto->set_offset(value * multiplier); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer") && + getattr(arg, "is_integer")().cast()) { + int64_t value = arg.cast(); + proto->set_offset(value * multiplier); + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = objtype.attr("__name__").cast(); + py::print(arg); + ThrowError(PyExc_TypeError, + absl::StrCat("Cannot convert '", absl::CEscape(type_name), + "' to a linear expression.")); + } +} + +std::shared_ptr CpBaseModel::AddAllDifferentInternal( + py::args exprs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + if (exprs.size() == 1 && py::isinstance(exprs[0])) { + for (const auto& expr : exprs[0]) { + LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); + } + } else { + for (const auto& expr : exprs) { + LinearExprToProto(expr, 1, ct->mutable_all_diff()->add_exprs()); + } + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddAutomatonInternal( + py::sequence transition_expressions, int64_t starting_state, + const std::vector& final_states, + const std::vector>& transition_triples) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + for (const auto& expr : transition_expressions) { + LinearExprToProto(expr, 1, ct->mutable_automaton()->add_exprs()); + } + ct->mutable_automaton()->set_starting_state(starting_state); + ct->mutable_automaton()->mutable_final_states()->Add(final_states.begin(), + final_states.end()); + for (const auto& tuple : transition_triples) { + if (tuple.size() != 3) { + ThrowError(PyExc_ValueError, + absl::StrCat("transition (", absl::StrJoin(tuple, ","), + ") has the wrong arity != 3")); + } + ct->mutable_automaton()->add_transition_tail(tuple[0]); + ct->mutable_automaton()->add_transition_label(tuple[1]); + ct->mutable_automaton()->add_transition_head(tuple[2]); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddBoolArgumentConstraintInternal( + BoolArgumentConstraint type, py::args literals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + BoolArgumentProto* proto = nullptr; + switch (type) { + case BoolArgumentConstraint::kAtMostOne: + proto = ct->mutable_at_most_one(); + break; + case BoolArgumentConstraint::kBoolAnd: + proto = ct->mutable_bool_and(); + break; + case BoolArgumentConstraint::kBoolOr: + proto = ct->mutable_bool_or(); + break; + case BoolArgumentConstraint::kBoolXor: + proto = ct->mutable_bool_xor(); + break; + case BoolArgumentConstraint::kExactlyOne: + proto = ct->mutable_exactly_one(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown boolean argument constraint: ", type)); + } + if (literals.size() == 1 && py::isinstance(literals[0])) { + for (const auto& literal : literals[0]) { + proto->add_literals(GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + proto->add_literals(GetOrMakeBooleanIndex(literal)); + } + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddBoundedLinearExpressionInternal( + BoundedLinearExpression* ble) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + for (const auto& var : ble->vars()) { + ct->mutable_linear()->add_vars(var->index()); + } + for (const int64_t coeff : ble->coeffs()) { + ct->mutable_linear()->add_coeffs(coeff); + } + const int64_t offset = ble->offset(); + const Domain& bounds = ble->bounds(); + for (const int64_t bound : bounds.FlattenedIntervals()) { + if (bound == kint64min || bound == kint64max) { + ct->mutable_linear()->add_domain(bound); + } else { + ct->mutable_linear()->add_domain(CapSub(bound, offset)); + } + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddElementInternal( + const py::handle& index, py::sequence exprs, const py::handle& target) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + LinearExprToProto(index, 1, ct->mutable_element()->mutable_linear_index()); + for (const auto& expr : exprs) { + LinearExprToProto(expr, 1, ct->mutable_element()->add_exprs()); + } + LinearExprToProto(target, 1, ct->mutable_element()->mutable_linear_target()); + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddInverseInternal( + py::sequence direct, py::sequence inverse) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + ct->mutable_inverse()->mutable_f_direct()->Reserve(direct.size()); + for (const auto& var : direct) { + ct->mutable_inverse()->add_f_direct(GetOrMakeVariableIndex(var)); + } + ct->mutable_inverse()->mutable_f_inverse()->Reserve(inverse.size()); + for (const auto& var : inverse) { + ct->mutable_inverse()->add_f_inverse(GetOrMakeVariableIndex(var)); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddLinearArgumentConstraintInternal( + LinearArgumentConstraint type, const py::handle& target, py::args exprs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + LinearArgumentProto* proto; + int64_t multiplier = 1; + switch (type) { + case LinearArgumentConstraint::kDiv: + proto = ct->mutable_int_div(); + break; + case LinearArgumentConstraint::kMax: + proto = ct->mutable_lin_max(); + break; + case LinearArgumentConstraint::kMin: + proto = ct->mutable_lin_max(); + multiplier = -1; + break; + case LinearArgumentConstraint::kMod: + proto = ct->mutable_int_mod(); + break; + case LinearArgumentConstraint::kProd: + proto = ct->mutable_int_prod(); + break; + default: + ThrowError(PyExc_ValueError, + absl::StrCat("Unknown integer argument constraint: ", type)); + } + + LinearExprToProto(target, multiplier, proto->mutable_target()); + + if (exprs.size() == 1 && py::isinstance(exprs[0])) { + for (const auto& expr : exprs[0]) { + LinearExprToProto(expr, multiplier, proto->add_exprs()); + } + } else { + for (const auto& expr : exprs) { + LinearExprToProto(expr, multiplier, proto->add_exprs()); + } + } + + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddReservoirInternal( + py::sequence times, py::sequence level_changes, py::sequence actives, + int64_t min_level, int64_t max_level) { + const int ct_index = model_proto_->constraints_size(); + ReservoirConstraintProto* proto = + model_proto_->add_constraints()->mutable_reservoir(); + proto->mutable_time_exprs()->Reserve(times.size()); + for (const auto& time : times) { + LinearExprToProto(time, 1, proto->add_time_exprs()); + } + proto->mutable_level_changes()->Reserve(level_changes.size()); + for (const auto& change : level_changes) { + LinearExprToProto(change, 1, proto->add_level_changes()); + } + proto->mutable_active_literals()->Reserve(actives.size()); + for (const auto& active : actives) { + proto->add_active_literals(GetOrMakeBooleanIndex(active)); + } + proto->set_min_level(min_level); + proto->set_max_level(max_level); + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddTableInternal( + py::sequence exprs, const std::vector>& tuples, + bool negated) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + const int num_exprs = exprs.size(); + for (const auto& expr : exprs) { + LinearExprToProto(expr, 1, ct->mutable_table()->add_exprs()); + } + for (const auto& tuple : tuples) { + if (tuple.size() != num_exprs) { + ThrowError(PyExc_ValueError, + absl::StrCat("Tuple (", absl::StrJoin(tuple, ","), + ") has the wrong arity != ", num_exprs)); + } + ct->mutable_table()->mutable_values()->Add(tuple.begin(), tuple.end()); + } + ct->mutable_table()->set_negated(negated); + return std::make_shared(shared_from_this(), ct_index); +} + +std::string ShortName(int literal, std::shared_ptr model_proto) { + const int var = PositiveRef(literal); + const IntegerVariableProto& var_proto = model_proto->variables(var); + const std::string& var_name = + var_proto.name().empty() ? absl::StrCat("i", var) : var_proto.name(); + if (literal < 0) { + return absl::StrCat("not(", var_name, ")"); + } else { + return var_name; + } +} + +std::string ShortExprName(const LinearExpressionProto& expr, + std::shared_ptr model_proto) { + if (expr.vars().empty()) { + return absl::StrCat(expr.offset()); + } else { + const IntegerVariableProto& var_proto = + model_proto->variables(expr.vars(0)); + const std::string& var_name = var_proto.name().empty() + ? absl::StrCat("i", expr.vars(0)) + : var_proto.name(); + const int64_t coeff = expr.coeffs(0); + std::string result; + if (coeff == 1) { + result = var_name; + } else if (coeff == -1) { + result = absl::StrCat("-", var_name); + } else if (coeff != 0) { + result = absl::StrCat(coeff, " * ", var_name); + } + if (expr.offset() > 0) { + absl::StrAppend(&result, " + ", expr.offset()); + } else if (expr.offset() < 0) { + absl::StrAppend(&result, " - ", -expr.offset()); + } + return result; + } +} + +std::shared_ptr RebuildFromLinearExpressionProto( + const LinearExpressionProto& proto, + std::shared_ptr model_proto) { + if (proto.vars().empty()) { + return LinearExpr::ConstantInt(proto.offset()); + } else if (proto.vars_size() == 1) { + return LinearExpr::AffineInt( + std::make_shared(model_proto, proto.vars(0)), proto.coeffs(0), + proto.offset()); + } else { + std::vector> vars; + vars.reserve(proto.vars_size()); + for (const int var : proto.vars()) { + vars.push_back(std::make_shared(model_proto, var)); + } + return std::make_shared(vars, proto.coeffs(), + proto.offset()); + } +} + +class IntervalVar { + public: + IntervalVar(std::shared_ptr model_proto, int index) + : model_proto_(model_proto), index_(index) { + DCHECK_GE(index, 0); + } + + int index() const { return index_; } + + std::shared_ptr model_proto() const { return model_proto_; } + + ConstraintProto* proto() const { + return model_proto_->mutable_constraints(index_); + } + + std::string ToString() const { + const std::string name = proto()->name(); + if (name.empty()) { + return absl::StrCat("iv", index_); + } else { + return name; + } + } + + std::string DebugString() const { + if (proto()->enforcement_literal().empty()) { + return absl::StrCat( + name(), "(start = ", + ShortExprName(proto()->interval().start(), model_proto()), + ", size = ", ShortExprName(proto()->interval().size(), model_proto()), + ", end = ", ShortExprName(proto()->interval().end(), model_proto()), + ")"); + } else { + return absl::StrCat( + name(), "(start = ", + ShortExprName(proto()->interval().start(), model_proto()), + ", size = ", ShortExprName(proto()->interval().size(), model_proto()), + ", end = ", ShortExprName(proto()->interval().end(), model_proto()), + ", is_present = ", + ShortName(proto()->enforcement_literal(0), model_proto()), ")"); + } + } + + std::string name() const { return proto()->name(); } + + void SetName(const std::string& name) { proto()->set_name(name); } + + std::shared_ptr StartExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().start(), + model_proto_); + } + std::shared_ptr SizeExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().size(), + model_proto_); + } + std::shared_ptr EndExpr() const { + return RebuildFromLinearExpressionProto(proto()->interval().end(), + model_proto_); + } + + std::vector> PresenceLiterals() const { + std::vector> literals; + for (const int lit : proto()->enforcement_literal()) { + if (RefIsPositive(lit)) { + literals.push_back(std::make_shared(model_proto_, lit)); + } else { + literals.push_back(std::make_shared( + model_proto_, NegatedRef(lit))); + } + } + return literals; + } + + private: + std::shared_ptr model_proto_; + int index_; +}; + +std::shared_ptr CpBaseModel::NewIntervalVarInternal( + const std::string& name, const py::handle& start, const py::handle& size, + const py::handle& end, py::sequence literals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + if (!name.empty()) ct->set_name(name); + LinearExprToProto(start, 1, ct->mutable_interval()->mutable_start()); + LinearExprToProto(size, 1, ct->mutable_interval()->mutable_size()); + LinearExprToProto(end, 1, ct->mutable_interval()->mutable_end()); + for (const auto& lit : literals) { + ct->add_enforcement_literal(GetOrMakeBooleanIndex(lit)); + } + const std::string method = literals.empty() + ? "cp_model.new_interval_var" + : "cp_model.new_optional_interval_var"; + if (ct->interval().start().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": start must be affine or constant.")); + } + if (ct->interval().size().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": size must be affine or constant.")); + } + if (ct->interval().end().vars_size() > 1) { + ThrowError(PyExc_TypeError, + absl::StrCat(method, ": end must be affine or constant.")); + } + return std::make_shared(model_proto_, ct_index); +} + +std::shared_ptr CpBaseModel::AddNoOverlapInternal( + const std::vector>& intervals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + ct->mutable_no_overlap()->mutable_intervals()->Reserve(intervals.size()); + for (const std::shared_ptr& interval : intervals) { + ct->mutable_no_overlap()->add_intervals(interval->index()); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddNoOverlap2DInternal( + const std::vector>& x_intervals, + const std::vector>& y_intervals) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + ct->mutable_no_overlap_2d()->mutable_x_intervals()->Reserve( + x_intervals.size()); + for (const std::shared_ptr& x_interval : x_intervals) { + ct->mutable_no_overlap_2d()->add_x_intervals(x_interval->index()); + } + ct->mutable_no_overlap_2d()->mutable_y_intervals()->Reserve( + y_intervals.size()); + for (const std::shared_ptr& y_interval : y_intervals) { + ct->mutable_no_overlap_2d()->add_y_intervals(y_interval->index()); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddCumulativeInternal( + const std::vector>& intervals, + const py::sequence demands, const py::handle& capacity) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + CumulativeConstraintProto* proto = ct->mutable_cumulative(); + + proto->mutable_intervals()->Reserve(intervals.size()); + for (const std::shared_ptr& interval : intervals) { + proto->add_intervals(interval->index()); + } + + proto->mutable_demands()->Reserve(demands.size()); + for (const auto& expr : demands) { + LinearExprToProto(expr, 1, proto->add_demands()); + } + + LinearExprToProto(capacity, 1, proto->mutable_capacity()); + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddCircuitInternal( + const std::vector>& arcs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + CircuitConstraintProto* proto = ct->mutable_circuit(); + proto->mutable_tails()->Reserve(arcs.size()); + proto->mutable_heads()->Reserve(arcs.size()); + proto->mutable_literals()->Reserve(arcs.size()); + for (const auto& [tail, head, lit] : arcs) { + proto->add_tails(tail); + proto->add_heads(head); + proto->add_literals(GetOrMakeBooleanIndex(lit)); + } + return std::make_shared(shared_from_this(), ct_index); +} + +std::shared_ptr CpBaseModel::AddRoutesInternal( + const std::vector>& arcs) { + const int ct_index = model_proto_->constraints_size(); + ConstraintProto* ct = model_proto_->add_constraints(); + RoutesConstraintProto* proto = ct->mutable_routes(); + proto->mutable_tails()->Reserve(arcs.size()); + proto->mutable_heads()->Reserve(arcs.size()); + proto->mutable_literals()->Reserve(arcs.size()); + for (const auto& [tail, head, lit] : arcs) { + proto->add_tails(tail); + proto->add_heads(head); + proto->add_literals(GetOrMakeBooleanIndex(lit)); + } + return std::make_shared(shared_from_this(), ct_index); +} + +class Constraint { + public: + // We need to store the CpBaseModel to convert enforcement literals to + // indices. + Constraint(std::shared_ptr model, int index) + : model_(model), index_(index) {} + + int index() const { return index_; } + + std::shared_ptr model_proto() const { + return model_->model_proto(); + } + + ConstraintProto* proto() const { + return model_->model_proto()->mutable_constraints(index_); + } + + std::shared_ptr model() const { return model_; } + + std::string name() const { return proto()->name(); } + void SetName(const std::string& name) { proto()->set_name(name); } + void ClearName() { proto()->clear_name(); } + + std::string ToString() const { + return absl::StrCat("Constraint(index=", index_, ", ", + ProtobufDebugString(*proto()), ")"); + } + + private: + std::shared_ptr model_; + int index_; +}; + +} // namespace + +void DefinePybindWrapperForConstraints(py::module& m) { + py::enum_(m, "BoolArgumentConstraint") + .value("at_most_one", BoolArgumentConstraint::kAtMostOne) + .value("bool_and", BoolArgumentConstraint::kBoolAnd) + .value("bool_or", BoolArgumentConstraint::kBoolOr) + .value("bool_xor", BoolArgumentConstraint::kBoolXor) + .value("exactly_one", BoolArgumentConstraint::kExactlyOne) + .export_values(); + + py::enum_(m, "LinearArgumentConstraint") + .value("div", LinearArgumentConstraint::kDiv) + .value("max", LinearArgumentConstraint::kMax) + .value("min", LinearArgumentConstraint::kMin) + .value("mod", LinearArgumentConstraint::kMod) + .value("prod", LinearArgumentConstraint::kProd) + .export_values(); + + py::class_>( + m, "CpBaseModel", "Base class for the CP model.") + .def(py::init>()) + .def_property_readonly("model_proto", &CpBaseModel::model_proto, + "Returns the CP model protobuf") + .def("get_or_make_index_from_constant", + &CpBaseModel::GetOrMakeIndexFromConstant, py::arg("value"), + "Returns the index of the given constant value.") + .def("get_or_make_boolean_index", &CpBaseModel::GetOrMakeBooleanIndex, + py::arg("value"), "Returns the index of the given boolean value.") + .def("get_or_make_variable_index", &CpBaseModel::GetOrMakeVariableIndex, + py::arg("arg"), + "Returns the index of the given variable or constant variable.") + .def("is_boolean_value", &CpBaseModel::IsBooleanValue, py::arg("value")) + .def("rebuild_constant_map", &CpBaseModel::RebuildConstantMap) + .def("_add_all_different", &CpBaseModel::AddAllDifferentInternal) + .def("_add_automaton", &CpBaseModel::AddAutomatonInternal, + py::arg("transition_expressions"), py::arg("starting_state"), + py::arg("final_states"), py::arg("transition_triples")) + .def("_add_bool_argument_constraint", + &CpBaseModel::AddBoolArgumentConstraintInternal, py::arg("name")) + .def("_add_bounded_linear_expression", + &CpBaseModel::AddBoundedLinearExpressionInternal, py::arg("ble")) + .def("_add_element", &CpBaseModel::AddElementInternal, + py::arg("index").none(false), py::arg("expressions"), + py::arg("target").none(false)) + .def("_add_linear_argument_constraint", + &CpBaseModel::AddLinearArgumentConstraintInternal, + py::arg("name").none(false), py::arg("target").none(false)) + .def("_add_inverse", &CpBaseModel::AddInverseInternal, py::arg("direct"), + py::arg("inverse")) + .def("_add_reservoir", &CpBaseModel::AddReservoirInternal, + py::arg("times"), py::arg("level_changes"), py::arg("actives"), + py::arg("min_level"), py::arg("max_level")) + .def("_add_table", &CpBaseModel::AddTableInternal, py::arg("expressions"), + py::arg("values"), py::arg("negated")) + // Scheduling support. + .def("_new_interval_var", &CpBaseModel::NewIntervalVarInternal, + py::arg("name"), py::arg("start"), py::arg("size"), py::arg("end"), + py::arg("Literals")) + .def("_add_no_overlap", &CpBaseModel::AddNoOverlapInternal, + py::arg("intervals")) + .def("_add_no_overlap_2d", &CpBaseModel::AddNoOverlap2DInternal, + py::arg("x_intervals"), py::arg("y_intervals")) + .def("_add_cumulative", &CpBaseModel::AddCumulativeInternal, + py::arg("intervals"), py::arg("demands"), py::arg("capacity")) + // Routing support. + .def("_add_circuit", &CpBaseModel::AddCircuitInternal, py::arg("arcs")) + .def("_add_routes", &CpBaseModel::AddRoutesInternal, py::arg("arcs")); + + static const char* kConstraintDoc = R"doc( + Base class for constraints. + + Constraints are built by the CpModel through the add methods. + Once created by the CpModel class, they are automatically added to the model. + The purpose of this class is to allow specification of enforcement literals + for this constraint. + + b = model.new_bool_var('b') + x = model.new_int_var(0, 10, 'x') + y = model.new_int_var(0, 10, 'y') + + model.add(x + 2 * y == 5).only_enforce_if(b.negated()) + )doc"; + + static const char* kConstraintOnlyEnforceIfDoc = R"doc( + Adds one or more enforcement literals to the constraint. + + This method adds one or more literals (that is, a boolean variable or its + negation) as enforcement literals. The conjunction of all these literals + determines whether the constraint is active or not. It acts as an + implication, so if the conjunction is true, it implies that the constraint + must be enforced. If it is false, then the constraint is ignored. + + BoolOr, BoolAnd, and linear constraints all support enforcement literals. + + Args: + *literals: One or more Boolean literals. + + Returns: + self.)doc"; + + py::class_>(m, "Constraint", + kConstraintDoc) + .def(py::init, int>()) + .def_property_readonly( + "index", &Constraint::index, + "Returns the index of the constraint in the model protobuf.") + .def_property_readonly("model_proto", &Constraint::model_proto, + "Returns the model protobuf.") + .def_property_readonly("proto", &Constraint::proto, + py::return_value_policy::reference_internal, + "Returns the ConstraintProto of this constraint.") + .def_property("name", &Constraint::name, &Constraint::SetName, + "The name of the constraint.") + .def( + "with_name", + [](Constraint* self, const std::string& name) { + if (name.empty()) { + self->ClearName(); + } else { + self->SetName(name); + } + return self; + }, + "Sets the name of the constraint and returns the constraints") + .def( + "only_enforce_if", + [](std::shared_ptr self, + std::shared_ptr literal) { + self->proto()->add_enforcement_literal(literal->index()); + return self; + }, + py::arg("literal"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, bool literal) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeIndexFromConstant(literal)); + return self; + }, + py::arg("literal"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, + const std::vector>& literals) { + for (const std::shared_ptr& literal : literals) { + self->proto()->add_enforcement_literal(literal->index()); + } + }, + py::arg("literals"), kConstraintOnlyEnforceIfDoc) + .def( + "only_enforce_if", + [](std::shared_ptr self, py::args literals) { + if (literals.size() == 1 && + py::isinstance(literals[0])) { + py::sequence seq = literals[0].cast(); + for (const auto& literal : seq) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } + }, + kConstraintOnlyEnforceIfDoc) + // Pre PEP8 compatibility. + .def("Name", &Constraint::name) + .def("Index", &Constraint::index) + .def("Proto", &Constraint::proto) + .def("WithName", + [](Constraint* self, const std::string& name) { + if (name.empty()) { + self->ClearName(); + } else { + self->SetName(name); + } + return self; + }) + .def("OnlyEnforceIf", [](std::shared_ptr self, + py::args literals) { + if (literals.size() == 1 && py::isinstance(literals[0])) { + py::sequence seq = literals[0].cast(); + for (const auto& literal : seq) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } else { + for (const auto& literal : literals) { + self->proto()->add_enforcement_literal( + self->model()->GetOrMakeBooleanIndex(literal)); + } + } + }); + + static const char* kIntervalVarDoc = R"doc( +Represents an Interval variable. + +An interval variable is both a constraint and a variable. It is defined by +three integer variables: start, size, and end. + +It is a constraint because, internally, it enforces that start + size == end. + +It is also a variable as it can appear in specific scheduling constraints: +NoOverlap, NoOverlap2D, Cumulative. + +Optionally, an enforcement literal can be added to this constraint, in which +case these scheduling constraints will ignore interval variables with +enforcement literals assigned to false. Conversely, these constraints will +also set these enforcement literals to false if they cannot fit these +intervals into the schedule. + +Raises: + ValueError: if start, size, end are not defined, or have the wrong type. +)doc"; + + py::class_>(m, "IntervalVar", + kIntervalVarDoc) + .def(py::init, int>()) + .def_property_readonly("index", &IntervalVar::index, + "Returns the index of the interval variable.") + .def_property_readonly("model_proto", &IntervalVar::model_proto, + "Returns the model protobuf.") + .def_property_readonly("proto", &IntervalVar::proto, + py::return_value_policy::reference_internal, + "Returns the interval constraint protobuf.") + .def_property("name", &IntervalVar::name, &IntervalVar::SetName, + "The name of the interval variable.") + .def( + "start_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.start().vars().empty()) { + return py::cast(proto.start().offset()); + } else { + return py::cast(self->StartExpr()); + } + }, + "Returns the start expression of the interval variable.") + .def( + "size_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.size().vars().empty()) { + return py::cast(proto.size().offset()); + } else { + return py::cast(self->SizeExpr()); + } + }, + "Returns the size expression of the interval variable.") + .def( + "end_expr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.end().vars().empty()) { + return py::cast(proto.end().offset()); + } else { + return py::cast(self->EndExpr()); + } + }, + "Returns the end expression of the interval variable.") + .def("presence_literals", &IntervalVar::PresenceLiterals, + "Returns the list of enforcement literals of the interval variable.") + .def("__str__", &IntervalVar::ToString) + .def("__repr__", &IntervalVar::DebugString) + .def(py::pickle( + [](std::shared_ptr p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p->model_proto(), p->index()); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) throw std::runtime_error("Invalid state!"); + + return std::make_shared( + t[0].cast>(), t[1].cast()); + })) + // Pre PEP8 compatibility layer. + .def("Proto", &IntervalVar::proto) + .def("Index", &IntervalVar::index) + .def("Name", &IntervalVar::name) + .def("StartExpr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.start().vars().empty()) { + return py::cast(proto.start().offset()); + } else { + return py::cast(self->StartExpr()); + } + }) + .def("SizeExpr", + [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.size().vars().empty()) { + return py::cast(proto.size().offset()); + } else { + return py::cast(self->SizeExpr()); + } + }) + .def("EndExpr", [](std::shared_ptr self) -> py::object { + const IntervalConstraintProto& proto = self->proto()->interval(); + if (proto.end().vars().empty()) { + return py::cast(proto.end().offset()); + } else { + return py::cast(self->EndExpr()); + } + }); + + m.def( + "rebuild_from_linear_expression_proto", + [](const LinearExpressionProto& proto, + std::shared_ptr model_proto) -> py::object { + if (proto.vars().empty()) { + return py::cast(proto.offset()); + } else { + return py::cast(RebuildFromLinearExpressionProto(proto, model_proto)); + } + }, + py::arg("proto"), py::arg("model_proto")); + + m.def( + "prettyprint_model_proto", + [](std::shared_ptr model_proto) -> py::object { +#if defined(ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR) + return py::cast(PrettyPrintModelProto(*model_proto)); +#else + throw std::runtime_error("unsupported: no proto descriptors"); +#endif // defined(ORTOOLS_TARGET_OS_SUPPORTS_PROTO_DESCRIPTOR) + }, + py::arg("model_proto")); +} + +} // namespace operations_research::sat::python diff --git a/ortools/sat/python/pybind_constraint.h b/ortools/sat/python/pybind_constraint.h new file mode 100644 index 00000000000..c5ba592552e --- /dev/null +++ b/ortools/sat/python/pybind_constraint.h @@ -0,0 +1,25 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORTOOLS_SAT_PYTHON_PYBIND_CONSTRAINT_H_ +#define ORTOOLS_SAT_PYTHON_PYBIND_CONSTRAINT_H_ + +#include "pybind11/pybind11.h" + +namespace operations_research::sat::python { + +void DefinePybindWrapperForConstraints(pybind11::module& m); + +} // namespace operations_research::sat::python + +#endif // ORTOOLS_SAT_PYTHON_PYBIND_CONSTRAINT_H_ diff --git a/ortools/sat/python/pybind_linearexpr.cc b/ortools/sat/python/pybind_linearexpr.cc new file mode 100644 index 00000000000..00c6f098cdb --- /dev/null +++ b/ortools/sat/python/pybind_linearexpr.cc @@ -0,0 +1,1022 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/pybind_linearexpr.h" + +#include +#include +#include +#include +#include +#include + +#include "ortools/base/types.h" + +#if PY_VERSION_HEX >= 0x030E0000 && !defined(PYPY_VERSION) // Python >= 3.14 +#define Py_BUILD_CORE +#include "internal/pycore_frame.h" +#include "internal/pycore_interpframe.h" +#undef Py_BUILD_CORE +#endif + +#define RunningOnValgrind AbslRunningOnValgrind +#define ValgrindSlowdown AbslValgrindSlowdown +#include "absl/base/dynamic_annotations.h" // IWYU pragma: keep +#undef RunningOnValgrind +#undef ValgrindSlowdown + +#include "absl/functional/any_invocable.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "ortools/sat/python/linear_expr.h" +#include "ortools/sat/python/linear_expr_doc.h" +#include "ortools/util/sorted_interval_list.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace operations_research::sat::python { + +namespace { + +namespace py = pybind11; + +void ThrowError(PyObject* py_exception, const std::string& message) { + PyErr_SetString(py_exception, message.c_str()); + throw py::error_already_set(); +} + +void RaiseIfNone(std::shared_ptr expr) { + if (expr == nullptr) { + ThrowError(PyExc_TypeError, + "Linear constraints do not accept None as argument."); + } +} + +void ProcessExprArg( + const py::handle& arg, + absl::AnyInvocable)> on_linear_expr, + absl::AnyInvocable on_int_constant, + absl::AnyInvocable on_float_constant) { + if (py::isinstance(arg)) { + on_linear_expr(arg.cast>()); + } else if (py::isinstance(arg)) { + on_int_constant(arg.cast()); + } else if (py::isinstance(arg)) { + on_float_constant(arg.cast()); + } else if (hasattr(arg, "dtype") && hasattr(arg, "is_integer")) { + if (getattr(arg, "is_integer")().cast()) { + on_int_constant(arg.cast()); + } else { + on_float_constant(arg.cast()); + } + } else { + py::type objtype = py::type::of(arg); + const std::string type_name = objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("LinearExpr::sum() only accept linear " + "expressions and constants as argument: '", + absl::CEscape(type_name), "'")); + } +} + +std::shared_ptr SumArguments(py::args expressions) { + std::vector> linear_exprs; + int64_t int_offset = 0; + double float_offset = 0.0; + bool has_floats = false; + + const auto process_arg = [&](const py::handle& arg) -> void { + ProcessExprArg( + arg, + [&](std::shared_ptr expr) { linear_exprs.push_back(expr); }, + [&](int64_t value) { int_offset += value; }, + [&](double value) { + if (value != 0.0) { + float_offset += value; + has_floats = true; + } + }); + }; + + if (expressions.size() == 1 && py::isinstance(expressions[0])) { + // Normal list or tuple argument. + py::sequence elements = expressions[0].cast(); + linear_exprs.reserve(elements.size()); + for (const py::handle& expr : elements) { + process_arg(expr); + } + } else { // Direct sum(x, y, 3, ..) without []. + linear_exprs.reserve(expressions.size()); + for (const py::handle expr : expressions) { + process_arg(expr); + } + } + + // If there are floats, we add the int offset to the float offset. + if (has_floats) { + float_offset += static_cast(int_offset); + int_offset = 0; + } + + if (linear_exprs.empty()) { + if (has_floats) { + return std::make_shared(float_offset); + } else { + return std::make_shared(int_offset); + } + } else if (linear_exprs.size() == 1) { + if (has_floats) { + if (float_offset == 0.0) { + return linear_exprs[0]; + } else { + return std::make_shared(linear_exprs[0], 1.0, + float_offset); + } + } else if (int_offset != 0) { + return std::make_shared(linear_exprs[0], 1, int_offset); + } else { + return linear_exprs[0]; + } + } else { + if (has_floats) { + return std::make_shared(linear_exprs, 0, float_offset); + } else { + return std::make_shared(linear_exprs, int_offset, 0.0); + } + } +} + +std::shared_ptr WeightedSumArguments(py::sequence expressions, + py::sequence coefficients) { + const int64_t size = expressions.size(); + if (size != coefficients.size()) { + ThrowError(PyExc_ValueError, + absl::StrCat("LinearExpr::weighted_sum() requires the same " + "number of arguments and coefficients: ", + size, " != ", coefficients.size())); + } + + std::vector> linear_exprs; + std::vector int_coeffs; + std::vector float_coeffs; + linear_exprs.reserve(size); + int_coeffs.reserve(size); + float_coeffs.reserve(size); + int64_t int_offset = 0; + double float_offset = 0.0; + bool has_floats = false; + bool fast_coeffs = false; + const void* raw_coeffs = nullptr; + Py_ssize_t coeff_stride = 0; + + enum { kInt64, kInt32, kDouble } coeff_type = kInt64; + + if (py::isinstance(coefficients)) { + py::array arr = coefficients.cast(); + + if (arr.ndim() == 1 && arr.size() == size) { + if (py::isinstance>(arr)) { + fast_coeffs = true; + raw_coeffs = arr.data(); + coeff_stride = arr.strides(0); + coeff_type = kInt64; + } else if (py::isinstance>(arr)) { + fast_coeffs = true; + raw_coeffs = arr.data(); + coeff_stride = arr.strides(0); + coeff_type = kInt32; + } else if (py::isinstance>(arr)) { + fast_coeffs = true; + raw_coeffs = arr.data(); + coeff_stride = arr.strides(0); + coeff_type = kDouble; + } + } + } + + for (int64_t i = 0; i < size; ++i) { + // --- Parse Coefficient --- + int64_t c_int = 0; + double c_float = 0.0; + bool c_is_float = false; + bool c_is_zero = false; + if (fast_coeffs) { + const char* ptr = static_cast(raw_coeffs) + i * coeff_stride; + if (coeff_type == kInt64) { + c_int = *reinterpret_cast(ptr); + if (c_int == 0) { + c_is_zero = true; + } else { + c_float = static_cast(c_int); + } + } else if (coeff_type == kInt32) { + c_int = *reinterpret_cast(ptr); + if (c_int == 0) { + c_is_zero = true; + } else { + c_float = static_cast(c_int); + } + } else { // kDouble + c_float = *reinterpret_cast(ptr); + if (c_float == 0.0) { + c_is_zero = true; + } else { + c_is_float = true; + has_floats = true; + } + } + } else { + const py::handle coeff_obj = coefficients[i]; + if (py::isinstance(coeff_obj)) { + c_int = coeff_obj.cast(); + if (c_int == 0) { + c_is_zero = true; + } else { + c_float = static_cast(c_int); + } + } else if (py::isinstance(coeff_obj)) { + c_float = coeff_obj.cast(); + if (c_float == 0.0) { + c_is_zero = true; + } else { + c_is_float = true; + has_floats = true; + } + } else if (hasattr(coeff_obj, "dtype") && + hasattr(coeff_obj, "is_integer")) { + if (getattr(coeff_obj, "is_integer")().cast()) { + c_int = coeff_obj.cast(); + if (c_int == 0) { + c_is_zero = true; + } else { + c_float = static_cast(c_int); + } + } else { + c_float = coeff_obj.cast(); + if (c_float == 0.0) { + c_is_zero = true; + } else { + c_is_float = true; + has_floats = true; + } + } + } else { + py::type objtype = py::type::of(coeff_obj); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError( + PyExc_TypeError, + absl::StrCat("LinearExpr::weighted_sum() only accept constants " + "as coefficients: '", + absl::CEscape(type_name), "'")); + } + } + + if (c_is_zero) continue; + + // --- Parse Expression --- + const py::handle expr_obj = expressions[i]; + if (py::isinstance(expr_obj)) { + linear_exprs.push_back(expr_obj.cast>()); + if (c_is_float) { + float_coeffs.push_back(c_float); + } else { + int_coeffs.push_back(c_int); + float_coeffs.push_back(c_float); + } + } else if (py::isinstance(expr_obj)) { + int64_t val = expr_obj.cast(); + if (val == 0) continue; + if (c_is_float) { + float_offset += c_float * static_cast(val); + } else { + int_offset += c_int * val; + } + } else if (py::isinstance(expr_obj)) { + double val = expr_obj.cast(); + if (val == 0.0) continue; + has_floats = true; + float_offset += c_float * val; + } else if (hasattr(expr_obj, "dtype") && hasattr(expr_obj, "is_integer")) { + if (getattr(expr_obj, "is_integer")().cast()) { + int64_t val = expr_obj.cast(); + if (val == 0) continue; + if (c_is_float) { + float_offset += c_float * static_cast(val); + } else { + int_offset += c_int * val; + } + } else { + double val = expr_obj.cast(); + if (val == 0.0) continue; + has_floats = true; + float_offset += c_float * val; + } + } else { + py::type objtype = py::type::of(expr_obj); + const std::string type_name = + objtype.attr("__name__").cast(); + ThrowError(PyExc_TypeError, + absl::StrCat("LinearExpr::weighted_sum() only accept linear " + "expressions and constants as argument: '", + absl::CEscape(type_name), "'")); + } + } + + // Correct the float offset if there are int offsets. + if (has_floats) { + float_offset += static_cast(int_offset); + int_offset = 0; + } + + if (linear_exprs.empty()) { + if (has_floats) { + return std::make_shared(float_offset); + } else { + return std::make_shared(int_offset); + } + } else if (linear_exprs.size() == 1) { + if (has_floats) { + return std::make_shared(linear_exprs[0], float_coeffs[0], + float_offset); + } else if (int_offset != 0 || int_coeffs[0] != 1) { + return std::make_shared(linear_exprs[0], int_coeffs[0], + int_offset); + } else { + return linear_exprs[0]; + } + } else { + if (has_floats) { + return std::make_shared(linear_exprs, float_coeffs, + float_offset); + } else { + return std::make_shared(linear_exprs, int_coeffs, + int_offset); + } + } +} + +// Checks that the result is not null and throws an error if it is. +std::shared_ptr CheckBoundedLinearExpression( + std::shared_ptr result, + std::shared_ptr lhs, + std::shared_ptr rhs = nullptr) { + if (!result->ok()) { + if (rhs == nullptr) { + ThrowError(PyExc_TypeError, + absl::StrCat("Linear constraints only accept integer values " + "and coefficients: ", + lhs->DebugString())); + } else { + ThrowError(PyExc_TypeError, + absl::StrCat("Linear constraints only accept integer values " + "and coefficients: ", + lhs->DebugString(), " and ", rhs->DebugString())); + } + } + return result; +} + +#if PY_VERSION_HEX >= 0x030E00A7 && !defined(PYPY_VERSION) +bool was_optimized_in_function_call(PyObject* op) { + PyFrameObject* frame = PyEval_GetFrame(); + if (frame == NULL) { + return false; + } + _PyInterpreterFrame* f = frame->f_frame; + _PyStackRef* base = _PyFrame_Stackbase(f); + _PyStackRef* stackpointer = f->stackpointer; + + while (stackpointer > base) { + stackpointer--; + if (op == PyStackRef_AsPyObjectBorrow(*stackpointer)) { + // We want detect if the object is a temporary and borrowed. If so, it + // should be only referenced once in the stack, but it should not be safe. + return !PyStackRef_IsHeapSafe(*stackpointer); + } + } + return false; +} + +bool IsOnwedExclusivelyThroughPyBind11(PyObject* op) { +#if !defined(Py_GIL_DISABLED) + return Py_REFCNT(op) == 3; +#else + // NOTE: the entire ob_ref_shared field must be zero, including flags, to + // ensure that other threads cannot concurrently create new references to + // this object. + return (_Py_IsOwnedByCurrentThread(op) && + _Py_atomic_load_uint32_relaxed(&op->ob_ref_local) == 3 && + _Py_atomic_load_ssize_relaxed(&op->ob_ref_shared) == 0); +#endif +} + +template +bool IsFree(std::shared_ptr expr) { + PyObject* op = py::cast(expr).ptr(); + return IsOnwedExclusivelyThroughPyBind11(op) && + !was_optimized_in_function_call(op); +} +#else +template +bool IsFree(std::shared_ptr expr) { + return Py_REFCNT(py::cast(expr).ptr()) == 4; +} +#endif + +} // namespace + +void DefinePybindWrapperForLinearExpr(py::module& m) { + py::class_>( + m, "LinearExpr", DOC(operations_research, sat, python, LinearExpr)) + .def_static("sum", &SumArguments, "Returns the sum(expressions).") + .def_static("weighted_sum", &WeightedSumArguments, py::arg("expressions"), + py::arg("coefficients"), + "Returns the sum of (expressions[i] * coefficients[i])") + .def_static("term", &LinearExpr::TermInt, py::arg("expr").none(false), + py::arg("coeff"), + DOC(operations_research, sat, python, LinearExpr, TermInt)) + .def_static("term", &LinearExpr::TermFloat, py::arg("expr").none(false), + py::arg("coeff"), + DOC(operations_research, sat, python, LinearExpr, TermFloat)) + .def_static("affine", &LinearExpr::AffineInt, py::arg("expr").none(false), + py::arg("coeff"), py::arg("offset"), + DOC(operations_research, sat, python, LinearExpr, AffineInt)) + .def_static( + "affine", &LinearExpr::AffineFloat, py::arg("expr").none(false), + py::arg("coeff"), py::arg("offset"), + DOC(operations_research, sat, python, LinearExpr, AffineFloat)) + .def_static( + "constant", &LinearExpr::ConstantInt, py::arg("value"), + DOC(operations_research, sat, python, LinearExpr, ConstantInt)) + .def_static( + "constant", &LinearExpr::ConstantFloat, py::arg("value"), + DOC(operations_research, sat, python, LinearExpr, ConstantFloat)) + // Pre PEP8 compatibility layer. + .def_static("Sum", &SumArguments) + .def_static("WeightedSum", &WeightedSumArguments, py::arg("expressions"), + py::arg("coefficients")) + .def_static("Term", &LinearExpr::TermInt, py::arg("expr").none(false), + py::arg("coeff"), "Returns expr * coeff.") + .def_static("Term", &LinearExpr::TermFloat, py::arg("expr").none(false), + py::arg("coeff"), "Returns expr * coeff.") + // Methods. + .def("__str__", + [](std::shared_ptr expr) -> std::string { + return expr->ToString(); + }) + .def("__repr__", + [](std::shared_ptr expr) -> std::string { + return expr->DebugString(); + }) + .def( + "is_integer", + [](std::shared_ptr expr) -> bool { + return expr->IsInteger(); + }, + DOC(operations_research, sat, python, LinearExpr, IsInteger)) + // Operators. + .def("__add__", &LinearExpr::Add, py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Add)) + .def("__add__", &LinearExpr::AddInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt)) + .def("__add__", &LinearExpr::AddFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddFloat)) + .def("__radd__", &LinearExpr::AddInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt)) + .def("__radd__", &LinearExpr::AddFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddFloat)) + .def("__sub__", &LinearExpr::Sub, py::arg("h").none(false), + DOC(operations_research, sat, python, LinearExpr, Sub)) + .def("__sub__", &LinearExpr::SubInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubInt)) + .def("__sub__", &LinearExpr::SubFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubFloat)) + .def("__rsub__", &LinearExpr::RSub, py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, RSub)) + .def("__rsub__", &LinearExpr::RSubInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, RSubInt)) + .def("__rsub__", &LinearExpr::RSubFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, RSubFloat)) + .def("__mul__", &LinearExpr::MulInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulInt)) + .def("__mul__", &LinearExpr::MulFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulFloat)) + .def("__rmul__", &LinearExpr::MulInt, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulInt)) + .def("__rmul__", &LinearExpr::MulFloat, py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, MulFloat)) + .def("__neg__", &LinearExpr::Neg, + DOC(operations_research, sat, python, LinearExpr, Neg)) + .def( + "__eq__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Eq(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Eq)) + .def( + "__eq__", + [](std::shared_ptr lhs, int64_t rhs) { + if (rhs == kint64max || rhs == kint64min) { + ThrowError(PyExc_ValueError, + "== INT_MIN or INT_MAX is not supported"); + } + return CheckBoundedLinearExpression(lhs->EqCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, EqCst)) + .def( + "__ne__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Ne(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Ne)) + .def( + "__ne__", + [](std::shared_ptr lhs, int64_t rhs) { + return CheckBoundedLinearExpression(lhs->NeCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, NeCst)) + .def( + "__le__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Le(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Le)) + .def( + "__le__", + [](std::shared_ptr lhs, int64_t rhs) { + if (rhs == kint64min) { + ThrowError(PyExc_ArithmeticError, "<= INT_MIN is not supported"); + } + return CheckBoundedLinearExpression(lhs->LeCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, LeCst)) + .def( + "__lt__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Lt(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Lt)) + .def( + "__lt__", + [](std::shared_ptr lhs, int64_t rhs) { + if (rhs == kint64min) { + ThrowError(PyExc_ArithmeticError, "< INT_MIN is not supported"); + } + return CheckBoundedLinearExpression(lhs->LtCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, LtCst)) + .def( + "__ge__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Ge(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Ge)) + .def( + "__ge__", + [](std::shared_ptr lhs, int64_t rhs) { + if (rhs == kint64max) { + ThrowError(PyExc_ArithmeticError, ">= INT_MAX is not supported"); + } + return CheckBoundedLinearExpression(lhs->GeCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, GeCst)) + .def( + "__gt__", + [](std::shared_ptr lhs, std::shared_ptr rhs) { + RaiseIfNone(rhs); + return CheckBoundedLinearExpression(lhs->Gt(rhs), lhs, rhs); + }, + DOC(operations_research, sat, python, LinearExpr, Gt)) + .def( + "__gt__", + [](std::shared_ptr lhs, int64_t rhs) { + if (rhs == kint64max) { + ThrowError(PyExc_ArithmeticError, "> INT_MAX is not supported"); + } + return CheckBoundedLinearExpression(lhs->GtCst(rhs), lhs); + }, + DOC(operations_research, sat, python, LinearExpr, GtCst)) + // Disable other operators as they are not supported. + .def("__div__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling / on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__truediv__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling // on a linear expression is not supported, " + "please use CpModel.add_division_equality"); + }) + .def("__mod__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling %% on a linear expression is not supported, " + "please use CpModel.add_modulo_equality"); + }) + .def("__pow__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling ** on a linear expression is not supported, " + "please use CpModel.add_multiplication_equality"); + }) + .def("__lshift__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling left shift on a linear expression is not " + "supported"); + }) + .def("__rshift__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling right shift on a linear expression is " + "not supported"); + }) + .def("__and__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling and on a linear expression is not supported"); + }) + .def("__or__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling or on a linear expression is not supported"); + }) + .def("__xor__", + [](std::shared_ptr /*self*/, py::handle /*other*/) { + ThrowError(PyExc_NotImplementedError, + "calling xor on a linear expression is not supported"); + }) + .def("__abs__", + [](std::shared_ptr /*self*/) { + ThrowError( + PyExc_NotImplementedError, + "calling abs() on a linear expression is not supported, " + "please use CpModel.add_abs_equality"); + }) + .def("__bool__", [](std::shared_ptr /*self*/) { + ThrowError(PyExc_NotImplementedError, + "Evaluating a LinearExpr instance as a Boolean is " + "not supported."); + }); + + // Expose Internal classes, mostly for testing. + py::class_, LinearExpr>( + m, "FlatFloatExpr", DOC(operations_research, sat, python, FlatFloatExpr)) + .def(py::init>()) + .def_property_readonly("vars", &FlatFloatExpr::vars) + .def_property_readonly("coeffs", &FlatFloatExpr::coeffs) + .def_property_readonly("offset", &FlatFloatExpr::offset); + + py::class_, LinearExpr>( + m, "SumArray", DOC(operations_research, sat, python, SumArray)) + .def( + py::init>, int64_t, double>()) + .def( + "__add__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Add)) + .def( + "__add__", + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); + }, + DOC(operations_research, sat, python, LinearExpr, AddInt)) + .def( + "__add__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, AddFloat)) + .def( + "__radd__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return IsFree(expr) ? expr->AddInPlace(other) : expr->Add(other); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Add)) + .def( + "__radd__", + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddIntInPlace(cst) : expr->AddInt(cst); + }, + py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddInt)) + .def( + "__radd__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(cst) + : expr->AddFloat(cst); + }, + py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, AddFloat)) + .def( + "__iadd__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return expr->AddInPlace(other); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Add)) + .def( + "__iadd__", + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return expr->AddIntInPlace(cst); + }, + DOC(operations_research, sat, python, LinearExpr, AddInt)) + .def( + "__iadd__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(cst); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, AddFloat)) + .def( + "__sub__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return IsFree(expr) ? expr->AddInPlace(other->Neg()) + : expr->Sub(other); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Sub)) + .def( + "__sub__", + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddIntInPlace(-cst) : expr->SubInt(cst); + }, + py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubInt)) + .def( + "__sub__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return IsFree(expr) ? expr->AddFloatInPlace(-cst) + : expr->SubFloat(cst); + }, + py::arg("cst"), + DOC(operations_research, sat, python, LinearExpr, SubFloat)) + .def( + "__isub__", + [](std::shared_ptr expr, + std::shared_ptr other) -> std::shared_ptr { + return expr->AddInPlace(other->Neg()); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, Sub)) + .def( + "__isub__", + [](std::shared_ptr expr, + int64_t cst) -> std::shared_ptr { + return expr->AddIntInPlace(-cst); + }, + DOC(operations_research, sat, python, LinearExpr, SubInt)) + .def( + "__isub__", + [](std::shared_ptr expr, + double cst) -> std::shared_ptr { + return expr->AddFloatInPlace(-cst); + }, + py::arg("other").none(false), + DOC(operations_research, sat, python, LinearExpr, SubFloat)) + .def_property_readonly("num_exprs", &SumArray::num_exprs) + .def_property_readonly("int_offset", &SumArray::int_offset) + .def_property_readonly("double_offset", &SumArray::double_offset); + + py::class_, LinearExpr>( + m, "FloatAffine", DOC(operations_research, sat, python, FloatAffine)) + .def(py::init, double, double>()) + .def_property_readonly("expression", &FloatAffine::expression) + .def_property_readonly("coefficient", &FloatAffine::coefficient) + .def_property_readonly("offset", &FloatAffine::offset); + + py::class_, LinearExpr>( + m, "IntAffine", DOC(operations_research, sat, python, IntAffine)) + .def(py::init, int64_t, int64_t>()) + .def_property_readonly("expression", &IntAffine::expression, + "Returns the linear expression.") + .def_property_readonly("coefficient", &IntAffine::coefficient, + "Returns the coefficient.") + .def_property_readonly("offset", &IntAffine::offset, + "Returns the offset."); + + py::class_, LinearExpr>( + m, "FlatIntExpr", DOC(operations_research, sat, python, FlatIntExpr)) + .def(py::init([](std::shared_ptr expr) { + FlatIntExpr* result = new FlatIntExpr(expr); + if (!result->ok()) { + ThrowError(PyExc_TypeError, + absl::StrCat("Tried to build a FlatIntExpr from a linear " + "expression with " + "floating point coefficients or constants: ", + expr->DebugString())); + } + return result; + })) + .def_property_readonly("vars", &FlatIntExpr::vars) + .def_property_readonly("coeffs", &FlatIntExpr::coeffs) + .def_property_readonly("offset", &FlatIntExpr::offset) + .def_property_readonly("ok", &FlatIntExpr::ok); + + py::class_, LinearExpr>( + m, "Literal", DOC(operations_research, sat, python, Literal)) + .def_property_readonly( + "index", &Literal::index, + DOC(operations_research, sat, python, Literal, index)) + .def("negated", &Literal::negated, + DOC(operations_research, sat, python, Literal, negated)) + .def("__invert__", &Literal::negated, + DOC(operations_research, sat, python, Literal, negated)) + .def("__bool__", + [](std::shared_ptr /*self*/) { + ThrowError(PyExc_NotImplementedError, + "Evaluating a Literal as a Boolean value is " + "not supported."); + }) + .def("__hash__", &Literal::Hash) + // Pre PEP8 compatibility layer. + .def("Not", &Literal::negated) + .def("Index", &Literal::index); + + // IntVar and NotBooleanVariable both hold a shared_ptr to the model_proto. + py::class_, Literal>( + m, "IntVar", DOC(operations_research, sat, python, IntVar)) + .def(py::init, int>()) + .def(py::init>()) // new variable. + .def_property_readonly( + "proto", &IntVar::proto, py::return_value_policy::reference_internal, + "Returns the IntegerVariableProto of this variable.") + .def_property_readonly("model_proto", &IntVar::model_proto, + "Returns the CP model protobuf") + .def_property_readonly( + "index", &IntVar::index, py::return_value_policy::reference, + DOC(operations_research, sat, python, IntVar, index)) + .def_property_readonly( + "is_boolean", &IntVar::is_boolean, + DOC(operations_research, sat, python, IntVar, is_boolean)) + .def_property("name", &IntVar::name, &IntVar::SetName, + "The name of the variable.") + .def( + "with_name", + [](std::shared_ptr self, const std::string& name) { + self->SetName(name); + return self; + }, + py::arg("name"), + "Sets the name of the variable and returns the variable.") + .def_property("domain", &IntVar::domain, &IntVar::SetDomain, + "The domain of the variable.") + .def( + "with_domain", + [](std::shared_ptr self, const Domain& domain) { + self->SetDomain(domain); + return self; + }, + py::arg("domain"), + "Sets the domain of the variable and returns the variable.") + .def("__str__", &IntVar::ToString) + .def("__repr__", &IntVar::DebugString) + .def( + "negated", + [](std::shared_ptr self) { + if (!self->is_boolean()) { + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); + } + return self->negated(); + }, + DOC(operations_research, sat, python, IntVar, negated)) + .def( + "__invert__", + [](std::shared_ptr self) { + if (!self->is_boolean()) { + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); + } + return self->negated(); + }, + DOC(operations_research, sat, python, IntVar, negated)) + .def("__copy__", + [](const std::shared_ptr& self) { + return std::make_shared(self->model_proto(), + self->index()); + }) + .def(py::pickle( + [](std::shared_ptr p) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(p->model_proto(), p->index()); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 2) throw std::runtime_error("Invalid state!"); + + return std::make_shared( + t[0].cast>(), t[1].cast()); + })) + // Pre PEP8 compatibility layer. + .def("Name", &IntVar::name) + .def("Proto", &IntVar::proto, py::return_value_policy::reference, + py::keep_alive<1, 0>(), + "Returns the IntegerVariableProto of this variable.") + .def("Not", + [](std::shared_ptr self) { + if (!self->is_boolean()) { + ThrowError(PyExc_TypeError, + "negated() is only supported for Boolean variables."); + } + return self->negated(); + }) + .def("Index", &IntVar::index); + + py::class_, Literal>( + m, "NotBooleanVariable", + DOC(operations_research, sat, python, NotBooleanVariable)) + .def_property_readonly( + "index", + [](std::shared_ptr not_var) -> int { + return not_var->index(); + }, + DOC(operations_research, sat, python, NotBooleanVariable, index)) + .def("__str__", + [](std::shared_ptr not_var) -> std::string { + return not_var->ToString(); + }) + .def("__repr__", + [](std::shared_ptr not_var) -> std::string { + return not_var->DebugString(); + }) + .def( + "negated", + [](std::shared_ptr not_var) + -> std::shared_ptr { return not_var->negated(); }, + DOC(operations_research, sat, python, NotBooleanVariable, negated)) + .def( + "__invert__", + [](std::shared_ptr not_var) + -> std::shared_ptr { return not_var->negated(); }, + DOC(operations_research, sat, python, NotBooleanVariable, negated)) + // Pre PEP8 compatibility layer. + .def( + "Not", + [](std::shared_ptr not_var) + -> std::shared_ptr { return not_var->negated(); }, + DOC(operations_research, sat, python, NotBooleanVariable, negated)); + + py::class_>( + m, "BoundedLinearExpression", + DOC(operations_research, sat, python, BoundedLinearExpression)) + .def(py::init, Domain>()) + .def(py::init, + const std::shared_ptr, Domain>()) + .def_property_readonly("bounds", &BoundedLinearExpression::bounds) + .def_property_readonly("vars", &BoundedLinearExpression::vars) + .def_property_readonly("coeffs", &BoundedLinearExpression::coeffs) + .def_property_readonly("offset", &BoundedLinearExpression::offset) + .def_property_readonly("ok", &BoundedLinearExpression::ok) + .def("__str__", &BoundedLinearExpression::ToString) + .def("__repr__", &BoundedLinearExpression::DebugString) + .def("__bool__", [](const BoundedLinearExpression& self) { + bool result; + if (self.CastToBool(&result)) return result; + ThrowError(PyExc_NotImplementedError, + absl::StrCat("Evaluating a BoundedLinearExpression '", + self.ToString(), + "'instance as a Boolean is " + "not supported.")); + return false; + }); +} // NOLINT(readability/fn_size) + +} // namespace operations_research::sat::python diff --git a/ortools/sat/python/pybind_linearexpr.h b/ortools/sat/python/pybind_linearexpr.h new file mode 100644 index 00000000000..0365adc25ce --- /dev/null +++ b/ortools/sat/python/pybind_linearexpr.h @@ -0,0 +1,25 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORTOOLS_SAT_PYTHON_PYBIND_LINEAREXPR_H_ +#define ORTOOLS_SAT_PYTHON_PYBIND_LINEAREXPR_H_ + +#include "pybind11/pybind11.h" + +namespace operations_research::sat::python { + +void DefinePybindWrapperForLinearExpr(pybind11::module& m); + +} // namespace operations_research::sat::python + +#endif // ORTOOLS_SAT_PYTHON_PYBIND_LINEAREXPR_H_ diff --git a/ortools/sat/python/pybind_solver.cc b/ortools/sat/python/pybind_solver.cc new file mode 100644 index 00000000000..daa803e90ee --- /dev/null +++ b/ortools/sat/python/pybind_solver.cc @@ -0,0 +1,284 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/python/pybind_solver.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/python/linear_expr.h" +#include "ortools/sat/swig_helper.h" +#include "pybind11/functional.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace operations_research::sat::python { + +namespace { + +namespace py = pybind11; + +void ThrowError(PyObject* py_exception, const std::string& message) { + PyErr_SetString(py_exception, message.c_str()); + throw py::error_already_set(); +} + +// We extend the SolverWrapper class to keep track of the local error already +// set. +class ExtSolveWrapper : public SolveWrapper { + public: + mutable std::optional local_error_already_set_; +}; + +// A trampoline class to override the OnSolutionCallback method to acquire the +// GIL. +class PySolutionCallback : public SolutionCallback { + public: + using SolutionCallback::SolutionCallback; /* Inherit constructors */ + void OnSolutionCallback() const override { + py::gil_scoped_acquire acquire; + try { + PYBIND11_OVERRIDE_PURE( + void, /* Return type */ + SolutionCallback, /* Parent class */ + OnSolutionCallback, /* Name of function */ + /* This function has no arguments. The trailing comma + in the previous line is needed for some compilers */ + ); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + ExtSolveWrapper* solve_wrapper = static_cast(wrapper()); + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + StopSearch(); + } + } +}; + +class ResponseHelper { + public: + static bool BooleanValue(std::shared_ptr response, + std::shared_ptr lit) { + const int index = lit->index(); + if (index >= 0) { + return response->solution(index) != 0; + } else { + return response->solution(NegatedRef(index)) == 0; + } + } + + static bool FixedBooleanValue(std::shared_ptr, bool lit) { + return lit; + } + + static std::vector SufficientAssumptionsForInfeasibility( + std::shared_ptr response) { + return std::vector( + response->sufficient_assumptions_for_infeasibility().begin(), + response->sufficient_assumptions_for_infeasibility().end()); + } + + static double FloatValue(std::shared_ptr response, + std::shared_ptr expr) { + FloatExprVisitor visitor; + visitor.AddToProcess(expr, 1); + return visitor.Evaluate(*response); + } + + static double FixedFloatValue(std::shared_ptr, + double value) { + return value; + } + + static int64_t Value(std::shared_ptr response, + std::shared_ptr expr) { + int64_t value; + IntExprVisitor visitor; + visitor.AddToProcess(expr, 1); + if (!visitor.Evaluate(*response, &value)) { + ThrowError(PyExc_ValueError, + absl::StrCat("Failed to evaluate linear expression: ", + expr->DebugString())); + } + return value; + } + + static int64_t FixedValue(std::shared_ptr, int64_t value) { + return value; + } +}; + +} // namespace + +void DefinePybindWrapperForSolver(py::module& m) { + // We keep the CamelCase name for the SolutionCallback class to be + // compatible with the pre PEP8 python code. + py::class_(m, "SolutionCallback") + .def(py::init<>()) + .def("OnSolutionCallback", &SolutionCallback::OnSolutionCallback) + .def("BestObjectiveBound", &SolutionCallback::BestObjectiveBound) + .def("DeterministicTime", &SolutionCallback::DeterministicTime) + .def("HasResponse", &SolutionCallback::HasResponse) + .def("NumBinaryPropagations", &SolutionCallback::NumBinaryPropagations) + .def("NumBooleans", &SolutionCallback::NumBooleans) + .def("NumBranches", &SolutionCallback::NumBranches) + .def("NumConflicts", &SolutionCallback::NumConflicts) + .def("NumIntegerPropagations", &SolutionCallback::NumIntegerPropagations) + .def("ObjectiveValue", &SolutionCallback::ObjectiveValue) + .def("Response", &SolutionCallback::SharedResponse) + .def("SolutionBooleanValue", &SolutionCallback::SolutionBooleanValue, + py::arg("index")) + .def("SolutionIntegerValue", &SolutionCallback::SolutionIntegerValue, + py::arg("index")) + .def("StopSearch", &SolutionCallback::StopSearch) + .def("UserTime", &SolutionCallback::UserTime) + .def("WallTime", &SolutionCallback::WallTime) + .def( + "Value", + [](const SolutionCallback& self, std::shared_ptr expr) { + return ResponseHelper::Value(self.SharedResponse(), expr); + }, + "Returns the value of a linear expression after solve.") + .def( + "Value", [](const SolutionCallback&, int64_t value) { return value; }, + "Returns the value of a linear expression after solve.") + .def( + "FloatValue", + [](const SolutionCallback& self, std::shared_ptr expr) { + return ResponseHelper::FloatValue(self.SharedResponse(), expr); + }, + "Returns the value of a floating point linear expression after " + "solve.") + .def( + "FloatValue", + [](const SolutionCallback&, double value) { return value; }, + "Returns the value of a floating point linear expression after " + "solve.") + .def( + "BooleanValue", + [](const SolutionCallback& self, std::shared_ptr lit) { + return ResponseHelper::BooleanValue(self.SharedResponse(), lit); + }, + "Returns the Boolean value of a literal after solve.") + .def( + "BooleanValue", [](const SolutionCallback&, bool lit) { return lit; }, + "Returns the Boolean value of a literal after solve."); + + py::class_(m, "ResponseHelper") + .def_static("boolean_value", &ResponseHelper::BooleanValue, + py::arg("response").none(false), py::arg("lit").none(false)) + .def_static("boolean_value", &ResponseHelper::FixedBooleanValue, + py::arg("response").none(false), py::arg("lit").none(false)) + .def_static("float_value", &ResponseHelper::FloatValue, + py::arg("response").none(false), py::arg("expr").none(false)) + .def_static("float_value", &ResponseHelper::FixedFloatValue, + py::arg("response").none(false), py::arg("value").none(false)) + .def_static("sufficient_assumptions_for_infeasibility", + &ResponseHelper::SufficientAssumptionsForInfeasibility, + py::arg("response").none(false)) + .def_static("value", &ResponseHelper::Value, + py::arg("response").none(false), py::arg("expr").none(false)) + .def_static("value", &ResponseHelper::FixedValue, + py::arg("response").none(false), + py::arg("value").none(false)); + + py::class_(m, "SolveWrapper") + .def(py::init<>()) + .def( + "add_log_callback", + [](ExtSolveWrapper* solve_wrapper, + std::function log_callback) { + std::function safe_log_callback = + [solve_wrapper, log_callback](std::string message) -> void { + py::gil_scoped_acquire acquire; + try { + log_callback(message); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + solve_wrapper->StopSearch(); + } + }; + solve_wrapper->AddLogCallback(safe_log_callback); + }, + py::arg("log_callback").none(false)) + .def("add_solution_callback", &SolveWrapper::AddSolutionCallback, + py::arg("callback")) + .def("clear_solution_callback", &SolveWrapper::ClearSolutionCallback) + .def( + "add_best_bound_callback", + [](ExtSolveWrapper* solve_wrapper, + std::function best_bound_callback) { + std::function safe_best_bound_callback = + [solve_wrapper, best_bound_callback](double bound) -> void { + py::gil_scoped_acquire acquire; + try { + best_bound_callback(bound); + } catch (py::error_already_set& e) { + // We assume this code is serialized as the gil is held. + if (!solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_ = e; + } + solve_wrapper->StopSearch(); + } + }; + solve_wrapper->AddBestBoundCallback(safe_best_bound_callback); + }, + py::arg("best_bound_callback").none(false)) + .def( + "set_parameters", + [](ExtSolveWrapper* solve_wrapper, + std::shared_ptr parameters) { + solve_wrapper->SetParameters(*parameters); + }, + py::arg("parameters").none(false)) + .def( + "solve", + [](ExtSolveWrapper* solve_wrapper, + std::shared_ptr model_proto) -> CpSolverResponse { + const auto result = [=]() -> CpSolverResponse { + py::gil_scoped_release release; + return solve_wrapper->Solve(*model_proto); + }(); + if (solve_wrapper->local_error_already_set_.has_value()) { + solve_wrapper->local_error_already_set_->restore(); + solve_wrapper->local_error_already_set_.reset(); + throw py::error_already_set(); + } + return result; + }, + py::arg("model_proto").none(false)) + .def("stop_search", &SolveWrapper::StopSearch); + + py::class_(m, "CpSatHelper") + .def_static("model_stats", &CpSatHelper::ModelStats, + py::arg("model_proto")) + .def_static("solver_response_stats", &CpSatHelper::SolverResponseStats, + py::arg("response")) + .def_static("validate_model", &CpSatHelper::ValidateModel, + py::arg("model_proto")) + .def_static("write_model_to_file", &CpSatHelper::WriteModelToFile, + py::arg("model_proto"), py::arg("filename")); +} + +} // namespace operations_research::sat::python diff --git a/ortools/sat/python/pybind_solver.h b/ortools/sat/python/pybind_solver.h new file mode 100644 index 00000000000..55bf2bb691a --- /dev/null +++ b/ortools/sat/python/pybind_solver.h @@ -0,0 +1,25 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORTOOLS_SAT_PYTHON_PYBIND_SOLVER_H_ +#define ORTOOLS_SAT_PYTHON_PYBIND_SOLVER_H_ + +#include "pybind11/pybind11.h" + +namespace operations_research::sat::python { + +void DefinePybindWrapperForSolver(pybind11::module& m); + +} // namespace operations_research::sat::python + +#endif // ORTOOLS_SAT_PYTHON_PYBIND_SOLVER_H_ diff --git a/ortools/sat/restart.cc b/ortools/sat/restart.cc index e0e0d069b19..c1839b4098b 100644 --- a/ortools/sat/restart.cc +++ b/ortools/sat/restart.cc @@ -19,6 +19,7 @@ #include "absl/log/log.h" #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "ortools/port/proto_utils.h" #include "ortools/sat/sat_decision.h" #include "ortools/sat/sat_parameters.pb.h" @@ -50,9 +51,9 @@ void RestartPolicy::Reset() { strategies_.push_back(parameters_.restart_algorithms(i)); } if (strategies_.empty()) { - const std::vector string_values = absl::StrSplit( + const std::vector string_values = absl::StrSplit( parameters_.default_restart_algorithms(), ',', absl::SkipEmpty()); - for (const std::string& string_value : string_values) { + for (const absl::string_view string_value : string_values) { SatParameters::RestartAlgorithm tmp; if (!SatParameters::RestartAlgorithm_Parse(string_value, &tmp)) { LOG(WARNING) << "Couldn't parse the RestartAlgorithm name: '" diff --git a/ortools/sat/routing_cuts.cc b/ortools/sat/routing_cuts.cc index c05d9d1d3cf..93e1dd60bba 100644 --- a/ortools/sat/routing_cuts.cc +++ b/ortools/sat/routing_cuts.cc @@ -2932,13 +2932,13 @@ void ExtractAllSubsetsFromForest(absl::Span parent, subsets->clear(); // Starts by creating the corresponding graph and find the root. - util::StaticGraph graph(num_nodes, num_nodes - 1); + util::StaticGraph<>::Builder builder(num_nodes, num_nodes - 1); for (int i = 0; i < num_nodes; ++i) { if (parent[i] != i) { - graph.AddArc(parent[i], i); + builder.AddArc(parent[i], i); } } - graph.Build(); + const auto graph = std::move(builder).BuildGraph(nullptr); // Perform a dfs on the rooted tree. // The subset_data will just be the node in post-order. diff --git a/ortools/sat/routing_cuts.h b/ortools/sat/routing_cuts.h index f9bca8b915c..11d1704c1fb 100644 --- a/ortools/sat/routing_cuts.h +++ b/ortools/sat/routing_cuts.h @@ -30,6 +30,7 @@ #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cuts.h" #include "ortools/sat/integer.h" @@ -666,10 +667,10 @@ void GenerateInterestingSubsets(int num_nodes, // // TODO(user): This also allocate O(n) memory internally, we could reuse it from // call to call if needed. -void ExtractAllSubsetsFromForest( - absl::Span parent, std::vector* subset_data, - std::vector>* subsets, - int node_limit = std::numeric_limits::max()); +void ExtractAllSubsetsFromForest(absl::Span parent, + std::vector* subset_data, + std::vector>* subsets, + int node_limit = kint32max); // In the routing context, we usually always have lp_value in [0, 1] and only // looks at arcs with a lp_value that is not too close to zero. diff --git a/ortools/sat/samples/assignment_groups_sat.py b/ortools/sat/samples/assignment_groups_sat.py index 2ea4fc973b0..ba620317028 100644 --- a/ortools/sat/samples/assignment_groups_sat.py +++ b/ortools/sat/samples/assignment_groups_sat.py @@ -14,6 +14,7 @@ # [START program] """Solves an assignment problem for given group of workers.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/assignment_task_sizes_sat.py b/ortools/sat/samples/assignment_task_sizes_sat.py index 0baca4a6df0..56abbc96635 100644 --- a/ortools/sat/samples/assignment_task_sizes_sat.py +++ b/ortools/sat/samples/assignment_task_sizes_sat.py @@ -14,6 +14,7 @@ # [START program] """Solves a simple assignment problem.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/assignment_teams_sat.py b/ortools/sat/samples/assignment_teams_sat.py index 375087e856c..f350ce435d0 100644 --- a/ortools/sat/samples/assignment_teams_sat.py +++ b/ortools/sat/samples/assignment_teams_sat.py @@ -14,6 +14,7 @@ # [START program] """Solves a simple assignment problem.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/assumptions_sample_sat.py b/ortools/sat/samples/assumptions_sample_sat.py index 62501b9b2f3..b833627ab51 100644 --- a/ortools/sat/samples/assumptions_sample_sat.py +++ b/ortools/sat/samples/assumptions_sample_sat.py @@ -13,6 +13,7 @@ # limitations under the License. """Code sample that solves a model and gets the infeasibility assumptions.""" + # [START program] # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/binpacking_problem_sat.py b/ortools/sat/samples/binpacking_problem_sat.py index 29bbe84da17..9dabf1f2028 100644 --- a/ortools/sat/samples/binpacking_problem_sat.py +++ b/ortools/sat/samples/binpacking_problem_sat.py @@ -15,7 +15,6 @@ # [START program] """Solves a binpacking problem using the CP-SAT solver.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/bool_or_sample_sat.py b/ortools/sat/samples/bool_or_sample_sat.py index b03c75066ef..47cb207168f 100644 --- a/ortools/sat/samples/bool_or_sample_sat.py +++ b/ortools/sat/samples/bool_or_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Code sample to demonstrates a simple Boolean constraint.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/boolean_product_sample_sat.py b/ortools/sat/samples/boolean_product_sample_sat.py index c815db67c61..efd4ec35e68 100644 --- a/ortools/sat/samples/boolean_product_sample_sat.py +++ b/ortools/sat/samples/boolean_product_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Code sample that encodes the product of two Boolean variables.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/channeling_sample_sat.py b/ortools/sat/samples/channeling_sample_sat.py index 805b261d573..c37a13ef72d 100644 --- a/ortools/sat/samples/channeling_sample_sat.py +++ b/ortools/sat/samples/channeling_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Link integer constraints together.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/cp_is_fun_sat.py b/ortools/sat/samples/cp_is_fun_sat.py index 7a8aeaedc02..2f1ebf71004 100644 --- a/ortools/sat/samples/cp_is_fun_sat.py +++ b/ortools/sat/samples/cp_is_fun_sat.py @@ -20,6 +20,7 @@ This problem has 72 different solutions in base 10. """ + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/cp_sat_example.py b/ortools/sat/samples/cp_sat_example.py index f7f68b23659..110f20585f7 100755 --- a/ortools/sat/samples/cp_sat_example.py +++ b/ortools/sat/samples/cp_sat_example.py @@ -14,6 +14,7 @@ # [START program] """Simple solve.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.py b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.py index c8a1fbe6f59..2e8bcab06b9 100644 --- a/ortools/sat/samples/earliness_tardiness_cost_sample_sat.py +++ b/ortools/sat/samples/earliness_tardiness_cost_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Encodes a convex piecewise linear function.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/interval_sample_sat.py b/ortools/sat/samples/interval_sample_sat.py index a42e8e6551a..3aaf5d327d7 100644 --- a/ortools/sat/samples/interval_sample_sat.py +++ b/ortools/sat/samples/interval_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Code sample to demonstrates how to build an interval.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/literal_sample_sat.py b/ortools/sat/samples/literal_sample_sat.py index 809dcab8e37..56a485b5113 100644 --- a/ortools/sat/samples/literal_sample_sat.py +++ b/ortools/sat/samples/literal_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Code sample to demonstrate Boolean variable and literals.""" - from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/minimal_jobshop_sat.py b/ortools/sat/samples/minimal_jobshop_sat.py index 6690d43ea05..a70de8454ad 100644 --- a/ortools/sat/samples/minimal_jobshop_sat.py +++ b/ortools/sat/samples/minimal_jobshop_sat.py @@ -14,6 +14,7 @@ # [START program] """Minimal jobshop example.""" + # [START import] import collections diff --git a/ortools/sat/samples/multiple_knapsack_sat.py b/ortools/sat/samples/multiple_knapsack_sat.py index 3f3b3e567a7..9e1d2febb58 100644 --- a/ortools/sat/samples/multiple_knapsack_sat.py +++ b/ortools/sat/samples/multiple_knapsack_sat.py @@ -14,6 +14,7 @@ # [START program] """Solves a multiple knapsack problem using the CP-SAT solver.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/nqueens_sat.py b/ortools/sat/samples/nqueens_sat.py index de6e6d97f2a..06e5bbcfb01 100644 --- a/ortools/sat/samples/nqueens_sat.py +++ b/ortools/sat/samples/nqueens_sat.py @@ -14,6 +14,7 @@ # [START program] """OR-Tools solution to the N-queens problem.""" + # [START import] import sys import time diff --git a/ortools/sat/samples/nurses_sat.py b/ortools/sat/samples/nurses_sat.py index 16fae1af178..9095d6e6aff 100644 --- a/ortools/sat/samples/nurses_sat.py +++ b/ortools/sat/samples/nurses_sat.py @@ -14,6 +14,7 @@ # [START program] """Example of a simple nurse scheduling problem.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/ranking_circuit_sample_sat.py b/ortools/sat/samples/ranking_circuit_sample_sat.py index 6a93bf82b07..70828f9f297 100644 --- a/ortools/sat/samples/ranking_circuit_sample_sat.py +++ b/ortools/sat/samples/ranking_circuit_sample_sat.py @@ -15,7 +15,6 @@ # [START program] """Code sample to demonstrates how to rank intervals using a circuit.""" - from collections.abc import Sequence from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/schedule_requests_sat.py b/ortools/sat/samples/schedule_requests_sat.py index f89e5475a98..a362725cd0d 100644 --- a/ortools/sat/samples/schedule_requests_sat.py +++ b/ortools/sat/samples/schedule_requests_sat.py @@ -14,6 +14,7 @@ # [START program] """Nurse scheduling problem with shift requests.""" + # [START import] from typing import Union diff --git a/ortools/sat/samples/scheduling_with_calendar_sample_sat.py b/ortools/sat/samples/scheduling_with_calendar_sample_sat.py index 6b76b92a8a9..ff6d9b88842 100644 --- a/ortools/sat/samples/scheduling_with_calendar_sample_sat.py +++ b/ortools/sat/samples/scheduling_with_calendar_sample_sat.py @@ -44,7 +44,7 @@ def scheduling_with_calendar_sample_sat(): # Because of the break, work cannot start at 13h. start = model.new_int_var_from_domain( - cp_model.Domain.from_intervals([(8, 12), (14, 15)]), "start" + cp_model.Domain.from_intervals([[8, 12], [14, 15]]), "start" ) duration = model.new_int_var(3, 4, "duration") end = model.new_int_var(8, 18, "end") diff --git a/ortools/sat/samples/simple_sat_program.py b/ortools/sat/samples/simple_sat_program.py index 3c2041c6cf9..796f85950f2 100644 --- a/ortools/sat/samples/simple_sat_program.py +++ b/ortools/sat/samples/simple_sat_program.py @@ -14,6 +14,7 @@ # [START program] """Simple solve.""" + # [START import] from ortools.sat.python import cp_model diff --git a/ortools/sat/samples/step_function_sample_sat.py b/ortools/sat/samples/step_function_sample_sat.py index 1f0f174a333..ffd6e0f39b0 100644 --- a/ortools/sat/samples/step_function_sample_sat.py +++ b/ortools/sat/samples/step_function_sample_sat.py @@ -54,14 +54,14 @@ def step_function_sample_sat(): # expr == 0 on [5, 6] U [8, 10] b0 = model.new_bool_var("b0") model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(5, 6), (8, 10)]) + x, cp_model.Domain.from_intervals([[5, 6], [8, 10]]) ).only_enforce_if(b0) model.add(expr == 0).only_enforce_if(b0) # expr == 2 on [0, 1] U [3, 4] U [11, 20] b2 = model.new_bool_var("b2") model.add_linear_expression_in_domain( - x, cp_model.Domain.from_intervals([(0, 1), (3, 4), (11, 20)]) + x, cp_model.Domain.from_intervals([[0, 1], [3, 4], [11, 20]]) ).only_enforce_if(b2) model.add(expr == 2).only_enforce_if(b2) diff --git a/ortools/sat/sat_decision.cc b/ortools/sat/sat_decision.cc index 8f4aedb22e6..3038673d8eb 100644 --- a/ortools/sat/sat_decision.cc +++ b/ortools/sat/sat_decision.cc @@ -43,7 +43,7 @@ namespace sat { SatDecisionPolicy::SatDecisionPolicy(Model* model) : parameters_(*(model->GetOrCreate())), trail_(*model->GetOrCreate()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), ls_hints_(model->GetOrCreate()) {} void SatDecisionPolicy::IncreaseNumVariables(int num_variables) { @@ -185,7 +185,7 @@ void SatDecisionPolicy::ResetInitialPolarity(int from, bool inverted) { var_polarity_[var] = trail_.Assignment().LiteralIsTrue(Literal(var, true)); } else { - var_polarity_[var] = absl::Bernoulli(*random_, 0.5); + var_polarity_[var] = absl::Bernoulli(random_, 0.5); } break; } @@ -211,7 +211,7 @@ bool SatDecisionPolicy::UseLsSolutionAsInitialPolarity() { // TODO(user): use cp_model_mapping. But this is not needed to experiment // on pure sat problems. std::shared_ptr solution = - ls_hints_->GetRandomBiasedSolution(*random_); + ls_hints_->GetRandomBiasedSolution(random_); if (solution->variable_values.size() != var_polarity_.size()) return false; for (int i = 0; i < solution->variable_values.size(); ++i) { @@ -231,7 +231,7 @@ void SatDecisionPolicy::FlipCurrentPolarity() { void SatDecisionPolicy::RandomizeCurrentPolarity() { const int num_variables = var_polarity_.size(); for (BooleanVariable var; var < num_variables; ++var) { - var_polarity_[var] = std::uniform_int_distribution(0, 1)(*random_); + var_polarity_[var] = std::uniform_int_distribution(0, 1)(random_); } } @@ -286,7 +286,7 @@ void SatDecisionPolicy::InitializeVariableOrdering() { std::reverse(tmp_variables_.begin(), tmp_variables_.end()); break; case SatParameters::IN_RANDOM_ORDER: - std::shuffle(tmp_variables_.begin(), tmp_variables_.end(), *random_); + std::shuffle(tmp_variables_.begin(), tmp_variables_.end(), random_); break; } @@ -391,7 +391,7 @@ Literal SatDecisionPolicy::NextBranch() { BooleanVariable var; const double ratio = parameters_.random_branches_ratio(); auto zero_to_one = [this]() { - return std::uniform_real_distribution()(*random_); + return std::uniform_real_distribution()(random_); }; if (ratio != 0.0 && zero_to_one() < ratio) { while (true) { @@ -399,7 +399,7 @@ Literal SatDecisionPolicy::NextBranch() { // variables are assigned. std::uniform_int_distribution index_dist(0, var_ordering_.Size() - 1); - var = var_ordering_.QueueElement(index_dist(*random_)).var; + var = var_ordering_.QueueElement(index_dist(random_)).var; if (!trail_.Assignment().VariableIsAssigned(var)) break; pq_need_update_for_var_at_trail_index_.Set(trail_.Info(var).trail_index); var_ordering_.Remove(var.value()); @@ -419,7 +419,7 @@ Literal SatDecisionPolicy::NextBranch() { // Choose its polarity (i.e. True of False). const double random_ratio = parameters_.random_polarity_ratio(); if (random_ratio != 0.0 && zero_to_one() < random_ratio) { - return Literal(var, std::uniform_int_distribution(0, 1)(*random_)); + return Literal(var, std::uniform_int_distribution(0, 1)(random_)); } if (has_forced_polarity_[var]) return Literal(var, forced_polarity_[var]); diff --git a/ortools/sat/sat_decision.h b/ortools/sat/sat_decision.h index 6be188035e0..7de1b671ed0 100644 --- a/ortools/sat/sat_decision.h +++ b/ortools/sat/sat_decision.h @@ -161,7 +161,7 @@ class SatDecisionPolicy { // Singleton model objects. const SatParameters& parameters_; const Trail& trail_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; // TODO(user): This is in term of proto indices. Ideally we would need // CpModelMapping to map that to Booleans but this currently lead to cyclic diff --git a/ortools/sat/sat_inprocessing.cc b/ortools/sat/sat_inprocessing.cc index fc20333e4c5..71183e94dad 100644 --- a/ortools/sat/sat_inprocessing.cc +++ b/ortools/sat/sat_inprocessing.cc @@ -35,6 +35,7 @@ #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/sat/clause.h" #include "ortools/sat/drat_checker.h" #include "ortools/sat/linear_programming_constraint.h" @@ -544,10 +545,15 @@ bool Inprocessing::RemoveFixedAndEquivalentVariables(bool log_info) { if (!implication_graph_->RemoveDuplicatesAndFixedVariables()) return false; } - // Invariant. There should be no clause with fixed variables left. + // Invariant. There should be no clause with fixed or redundant variables + // left. if (DEBUG_MODE) { + const auto& assignment = trail_->Assignment(); for (SatClause* clause : clause_manager_->AllClausesInCreationOrder()) { - CHECK(!SomeLiteralAreAssigned(trail_->Assignment(), clause->AsSpan())); + for (const Literal lit : clause->AsSpan()) { + CHECK(!assignment.LiteralIsAssigned(lit)); + CHECK(!implication_graph_->IsRedundant(lit)); + } } } @@ -778,7 +784,7 @@ bool Inprocessing::SubsumeAndStrenghtenRound(bool log_info) { // Important: we can only use this clause to subsume/strenghten others if // it cannot be deleted later. if (!clause_manager_->IsRemovable(clause)) { - int min_size = std::numeric_limits::max(); + int min_size = kint32max; LiteralIndex min_literal = kNoLiteralIndex; for (const Literal l : clause->AsSpan()) { if (one_watcher[l].size() < min_size) { diff --git a/ortools/sat/sat_inprocessing.h b/ortools/sat/sat_inprocessing.h index 2f297231913..450b3a2a8c6 100644 --- a/ortools/sat/sat_inprocessing.h +++ b/ortools/sat/sat_inprocessing.h @@ -213,7 +213,7 @@ class StampingSimplifier { implication_graph_(model->GetOrCreate()), clause_manager_(model->GetOrCreate()), lrat_proof_handler_(model->Mutable()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), time_limit_(model->GetOrCreate()) {} // This is "fast" (linear scan + sort of all clauses) so we always complete @@ -256,7 +256,7 @@ class StampingSimplifier { BinaryImplicationGraph* implication_graph_; ClauseManager* clause_manager_; LratProofHandler* lrat_proof_handler_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; TimeLimit* time_limit_; // For ComputeStampsForNextRound(). diff --git a/ortools/sat/sat_parameters.proto b/ortools/sat/sat_parameters.proto index c3a4521b0a8..3a58f1f19e0 100644 --- a/ortools/sat/sat_parameters.proto +++ b/ortools/sat/sat_parameters.proto @@ -26,7 +26,7 @@ option java_multiple_files = true; // Contains the definitions for all the sat algorithm parameters and their // default values. // -// NEXT TAG: 362 +// NEXT TAG: 364 message SatParameters { // In some context, like in a portfolio of search, it makes sense to name a // given parameters set for logging purpose. @@ -697,6 +697,12 @@ message SatParameters { // Note that we only have a basic version currently. optional bool inprocessing_use_congruence_closure = 342 [default = true]; + // Experimental. We don't have LRAT support for this. Recovers a "circuit" + // structure based on the gate detection used by congruence closure, and try + // to detect more equivalence via sampling. This is like the custom SAT + // sweeping done on a logic circuit with binary gates. + optional bool inprocessing_detect_and_sweep_circuit = 362 [default = false]; + // Whether we use the SAT sweeping algorithm described in "Clausal Equivalence // Sweeping", Armin Biere, Katalin Fazekas, Mathias Fleury, Nils Froleyks, // 2025. @@ -1200,6 +1206,10 @@ message SatParameters { // Conflict limit used in the phase that exploit the solution hint. optional int32 hint_conflict_limit = 153 [default = 10]; + // Try first to use the hint as assumption in the phase that exploit the + // solution hint. This should be way faster if the hint is completable. + optional bool try_hint_as_assumptions = 363 [default = true]; + // If true, the solver tries to repair the solution given in the hint. This // search terminates after the 'hint_conflict_limit' is reached and the solver // switches to regular search. If false, then we do a FIXED_SEARCH using the diff --git a/ortools/sat/sat_runner.cc b/ortools/sat/sat_runner.cc index 25f0484f423..c60c6258cf0 100644 --- a/ortools/sat/sat_runner.cc +++ b/ortools/sat/sat_runner.cc @@ -15,7 +15,6 @@ #include #include #include -#include #include #include #include @@ -60,11 +59,11 @@ ABSL_FLAG( ABSL_FLAG( std::string, hint_file, "", - "Protobuf file containing a CpModelResponse. The solution will be used as a" - " hint to bootstrap the search."); + "Protobuf file containing a CpSolverResponse. The solution will be used as " + "a hint to bootstrap the search."); ABSL_FLAG(std::string, domain_file, "", - "Protobuf file containing a CpModelResponse. If present, the " + "Protobuf file containing a CpSolverResponse. If present, the " "tightened models will be used to reduce the domain of variables."); ABSL_FLAG(std::string, output, "", @@ -276,8 +275,8 @@ bool LoadProblem(const std::string& filename, absl::string_view hint_file, // This way, the solver will return MODEL_INVALID instead of // crashing. IntegerVariableProto* var = cp_model->add_variables(); - var->add_domain(std::numeric_limits::min()); - var->add_domain(std::numeric_limits::max()); + var->add_domain(kint64min); + var->add_domain(kint64max); return true; // Will still call solve() to get the status. } } else { diff --git a/ortools/sat/sat_solver.cc b/ortools/sat/sat_solver.cc index 99fc10e7ddf..f4f726e2a79 100644 --- a/ortools/sat/sat_solver.cc +++ b/ortools/sat/sat_solver.cc @@ -34,6 +34,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/stl_util.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/port/proto_utils.h" #include "ortools/port/sysinfo.h" #include "ortools/sat/clause.h" @@ -702,7 +703,7 @@ bool SatSolver::ResetWithGivenAssumptions( // For assumptions and core-based search, it is really important to add as // many binary clauses as possible. This is because we do not want to miss any // early core of size 2. - ProcessNewlyFixedVariables(); + if (!ProcessNewlyFixedVariables()) return false; DCHECK(assumptions_.empty()); assumption_level_ = 1; @@ -1150,7 +1151,7 @@ void SatSolver::ProcessCurrentConflict( } // Learn the new clauses. - int best_lbd = std::numeric_limits::max(); + int best_lbd = kint32max; for (const auto& [clause, is_redundant, min_lbd, literals] : learned_clauses_) { DCHECK((lrat_proof_handler_ == nullptr) || (clause != kNullClausePtr)); @@ -1196,7 +1197,7 @@ std::pair SatSolver::SubsumptionsInConflictResolution( // conflict "shrinking" in the literature. std::vector subsumed_by_decisions; bool decision_is_redundant = true; - int decision_min_lbd = std::numeric_limits::max(); + int decision_min_lbd = kint32max; int decisions_clause_size = 0; if (assumption_level_ == 0 && parameters_->decision_subsumption_during_conflict_analysis()) { @@ -1282,7 +1283,7 @@ std::pair SatSolver::SubsumptionsInConflictResolution( // Then this clause subsumes all entry in the group. bool new_clause_is_redundant = true; - int new_clause_min_lbd = std::numeric_limits::max(); + int new_clause_min_lbd = kint32max; for (SatClause* clause : subsuming_groups_[i]) { CHECK_NE(clause->size(), 0); // Not subsumed yet. if (clauses_propagator_->IsRemovable(clause)) { @@ -1306,7 +1307,7 @@ std::pair SatSolver::SubsumptionsInConflictResolution( } bool is_redundant = true; - int min_lbd_of_subsumed_clauses = std::numeric_limits::max(); + int min_lbd_of_subsumed_clauses = kint32max; const auto in_decision = tmp_decision_set_.const_view(); const auto maybe_subsume = [&is_redundant, &min_lbd_of_subsumed_clauses, in_conflict, conflict, in_decision, @@ -1672,7 +1673,7 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit, const int64_t kDisplayFrequency = 10000; int64_t next_display = parameters_->log_search_progress() ? NextMultipleOf(num_failures(), kDisplayFrequency) - : std::numeric_limits::max(); + : kint64max; // Variables used to check the memory limit every kMemoryCheckFrequency. const int64_t kMemoryCheckFrequency = 10000; @@ -1682,8 +1683,8 @@ SatSolver::Status SatSolver::SolveInternal(TimeLimit* time_limit, // The max_number_of_conflicts is per solve but the counter is for the whole // solver. const int64_t kFailureLimit = - max_number_of_conflicts == std::numeric_limits::max() - ? std::numeric_limits::max() + max_number_of_conflicts == kint64max + ? kint64max : counters_.num_failures + max_number_of_conflicts; // Starts search. @@ -2022,10 +2023,10 @@ std::string SatSolver::RunningStatisticsString() const { num_variables_.value() - num_processed_fixed_variables_); } -void SatSolver::ProcessNewlyFixedVariables() { +bool SatSolver::ProcessNewlyFixedVariables() { SCOPED_TIME_STAT(&stats_); DCHECK_EQ(CurrentDecisionLevel(), 0); - if (num_processed_fixed_variables_ == trail_->Index()) return; + if (num_processed_fixed_variables_ == trail_->Index()) return true; num_processed_fixed_variables_ = trail_->Index(); int num_detached_clauses = 0; @@ -2066,15 +2067,13 @@ void SatSolver::ProcessNewlyFixedVariables() { // fixed literals, that is okay, we will clean them up on the next call to // ProcessNewlyFixedVariables(). // - // TODO(user): This still happen in SAT22.Carry_Save_Fast_1.cnf.cnf.xz, A - // better alternative is probably to make sure we only ever have cleaned - // clauses. We must clean them each time - // binary_implication_graph_->DetectEquivalence() is called, and we need - // to make sure we don't generate new clauses that are not cleaned up. + // Note that this is hard to avoid, because as we find new equivalences, + // more clause can become binary which might cause new equivalences... And + // making sure we reach a fix-point each time DetectEquivalence() is + // called is not so easy. if (trail_->Index() > saved_index) { if (!FinishPropagation()) { - SetModelUnsat(); - return; + return SetModelUnsat(); } saved_index = trail_->Index(); } @@ -2101,6 +2100,7 @@ void SatSolver::ProcessNewlyFixedVariables() { CHECK(binary_implication_graph_->Propagate(trail_)); binary_implication_graph_->RemoveFixedVariables(); deterministic_time_of_last_fixed_variables_cleanup_ = deterministic_time(); + return true; } bool SatSolver::PropagationIsDone() const { @@ -2149,6 +2149,22 @@ bool SatSolver::Propagate() { if (trail_->Index() == old_index) break; } + // We are back at level 0. This can happen because of a restart, or because + // we proved that some variables must take a given value in any satisfiable + // assignment. Trigger a simplification of the clauses if there is new fixed + // variables. Note that for efficiency reason, we don't do that too often. + // + // TODO(user): Do more advanced preprocessing? + if (CurrentDecisionLevel() == 0) { + const double kMinDeterministicTimeBetweenCleanups = 1.0; + if (num_processed_fixed_variables_ < trail_->Index() && + deterministic_time() > + deterministic_time_of_last_fixed_variables_cleanup_ + + kMinDeterministicTimeBetweenCleanups) { + if (!ProcessNewlyFixedVariables()) return false; + } + } + // In some corner cases, we might add new constraint during propagation, // which might trigger new propagator addition or some propagator to become // non-empty() now. @@ -2228,22 +2244,6 @@ void SatSolver::EnqueueNewDecision(Literal literal) { SCOPED_TIME_STAT(&stats_); CHECK(!Assignment().VariableIsAssigned(literal.Variable())); - // We are back at level 0. This can happen because of a restart, or because - // we proved that some variables must take a given value in any satisfiable - // assignment. Trigger a simplification of the clauses if there is new fixed - // variables. Note that for efficiency reason, we don't do that too often. - // - // TODO(user): Do more advanced preprocessing? - if (CurrentDecisionLevel() == 0) { - const double kMinDeterministicTimeBetweenCleanups = 1.0; - if (num_processed_fixed_variables_ < trail_->Index() && - deterministic_time() > - deterministic_time_of_last_fixed_variables_cleanup_ + - kMinDeterministicTimeBetweenCleanups) { - ProcessNewlyFixedVariables(); - } - } - counters_.num_branches++; last_decision_or_backtrack_trail_index_ = trail_->Index(); trail_->EnqueueSearchDecision(literal); @@ -2737,12 +2737,11 @@ void SatSolver::MinimizeConflictRecursively(std::vector* conflict, is_independent_.ClearAndResize(num_variables_); // min_trail_index_per_level_ will always be reset to all - // std::numeric_limits::max() at the end. This is used to prune the + // kint32max at the end. This is used to prune the // search because any literal at a given level with an index smaller or equal // to min_trail_index_per_level_[level] can't be redundant. if (CurrentDecisionLevel() >= min_trail_index_per_level_.size()) { - min_trail_index_per_level_.resize(CurrentDecisionLevel() + 1, - std::numeric_limits::max()); + min_trail_index_per_level_.resize(CurrentDecisionLevel() + 1, kint32max); } // Compute the number of variables at each decision level. This will be used @@ -2809,8 +2808,7 @@ void SatSolver::MinimizeConflictRecursively(std::vector* conflict, const int threshold = min_trail_index_per_level_.size() / 2; if (is_marked_.PositionsSetAtLeastOnce().size() < threshold) { for (BooleanVariable var : is_marked_.PositionsSetAtLeastOnce()) { - min_trail_index_per_level_[AssignmentLevel(var)] = - std::numeric_limits::max(); + min_trail_index_per_level_[AssignmentLevel(var)] = kint32max; } } else { min_trail_index_per_level_.clear(); diff --git a/ortools/sat/sat_solver.h b/ortools/sat/sat_solver.h index 10ac04f1df4..5c6871c1896 100644 --- a/ortools/sat/sat_solver.h +++ b/ortools/sat/sat_solver.h @@ -35,6 +35,7 @@ #include "absl/log/check.h" #include "absl/types/span.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/sat/clause.h" #include "ortools/sat/enforcement.h" #include "ortools/sat/lrat_proof_handler.h" @@ -100,7 +101,7 @@ class SatSolver { const int num_vars = NumVariables(); // We need to be able to encode the variable as a literal. - CHECK_LT(2 * num_vars, std::numeric_limits::max()); + CHECK_LT(2 * num_vars, kint32max); SetNumVariables(num_vars + 1); return BooleanVariable(num_vars); } @@ -387,7 +388,7 @@ class SatSolver { // It is important to process the newly fixed variables, so they are not // present in the clauses we export. if (num_processed_fixed_variables_ < trail_->Index()) { - ProcessNewlyFixedVariables(); + if (!ProcessNewlyFixedVariables()) return false; } clauses_propagator_->DeleteRemovedClauses(); @@ -506,7 +507,7 @@ class SatSolver { } // Simplifies the problem when new variables are assigned at level 0. - void ProcessNewlyFixedVariables(); + bool ProcessNewlyFixedVariables(); int64_t NumFixedVariables() const { if (CurrentDecisionLevel() > 0) { @@ -1078,8 +1079,9 @@ inline std::function ExcludeCurrentSolutionAndBacktrack() { for (int i = 0; i < current_level; ++i) { clause_to_exclude_solution.push_back(decisions[i].literal.Negated()); } - sat_solver->Backtrack(0); - AddClauseConstraint(clause_to_exclude_solution, model); + if (sat_solver->ResetToLevelZero()) { + AddClauseConstraint(clause_to_exclude_solution, model); + } }; } diff --git a/ortools/sat/sat_sweeping.cc b/ortools/sat/sat_sweeping.cc index 2d14152df58..9e8f3793d23 100644 --- a/ortools/sat/sat_sweeping.cc +++ b/ortools/sat/sat_sweeping.cc @@ -240,7 +240,7 @@ bool EquivalenceSatSweeping::DoOneRound( constexpr int kMaxTries = 10; for (tries = 0; tries < kMaxTries; ++tries) { next_candidate_var = RepresentativeVar( - BooleanVariable(absl::Uniform(*random_, 0, num_vars))); + BooleanVariable(absl::Uniform(random_, 0, num_vars))); if (var_to_clauses_[next_candidate_var].size() < 2) continue; const Literal positive_lit(next_candidate_var, true); if (implication_graph_->RepresentativeOf(positive_lit) != @@ -355,7 +355,7 @@ bool EquivalenceSatSweeping::DoOneRound( } else if (!result.new_equivalences.empty()) { // Try a different variable from the same neighborhood. const int var_index = - absl::Uniform(*random_, 0, bools_with_new_equivalences.size()); + absl::Uniform(random_, 0, bools_with_new_equivalences.size()); const BooleanVariable unmapped_next_candidate_var = bools_with_new_equivalences[var_index]; next_candidate_var = RepresentativeVar( diff --git a/ortools/sat/sat_sweeping.h b/ortools/sat/sat_sweeping.h index df9402f37ed..24240595622 100644 --- a/ortools/sat/sat_sweeping.h +++ b/ortools/sat/sat_sweeping.h @@ -55,7 +55,7 @@ class EquivalenceSatSweeping { implication_graph_(model->GetOrCreate()), clause_manager_(model->GetOrCreate()), global_time_limit_(model->GetOrCreate()), - random_(model->GetOrCreate()) {} + random_(*model->GetOrCreate()) {} // Do one round of equivalence SAT sweeping. // `run_inprocessing` is a function that is called on the model before solving @@ -79,7 +79,7 @@ class EquivalenceSatSweeping { BinaryImplicationGraph* implication_graph_; ClauseManager* clause_manager_; TimeLimit* global_time_limit_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; int max_num_clauses_ = 52000; int max_num_boolean_variables_ = 2000; diff --git a/ortools/sat/scheduling_helpers.cc b/ortools/sat/scheduling_helpers.cc index 0c35c51eaa2..10217d049ab 100644 --- a/ortools/sat/scheduling_helpers.cc +++ b/ortools/sat/scheduling_helpers.cc @@ -52,6 +52,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper( root_level_lin2_bounds_(model->GetOrCreate()), enforcement_helper_(*model->GetOrCreate()), enforcement_id_(-1), + fixed_search_(model->GetOrCreate()->search_branching() == + SatParameters::FIXED_SEARCH), starts_(std::move(starts)), ends_(std::move(ends)), sizes_(std::move(sizes)), @@ -94,6 +96,8 @@ SchedulingConstraintHelper::SchedulingConstraintHelper(int num_tasks, root_level_lin2_bounds_(model->GetOrCreate()), enforcement_helper_(*model->GetOrCreate()), enforcement_id_(-1), + fixed_search_(model->GetOrCreate()->search_branching() == + SatParameters::FIXED_SEARCH), capacity_(num_tasks), cached_size_min_(new IntegerValue[capacity_]), cached_start_min_(new IntegerValue[capacity_]), @@ -111,6 +115,10 @@ bool SchedulingConstraintHelper::IsEnforced() const { EnforcementStatus::IS_ENFORCED; } +bool SchedulingConstraintHelper::FixedSearchFirstSolutionMode() const { + return fixed_search_ && sat_solver_->num_failures() == 0; +} + bool SchedulingConstraintHelper::Propagate() { if (!IsEnforced()) return true; recompute_all_cache_ = true; diff --git a/ortools/sat/scheduling_helpers.h b/ortools/sat/scheduling_helpers.h index 8ed4ebe102f..5e29acd34e3 100644 --- a/ortools/sat/scheduling_helpers.h +++ b/ortools/sat/scheduling_helpers.h @@ -118,6 +118,12 @@ class SchedulingConstraintHelper : public PropagatorInterface { // Returns true if and only if all the enforcement literals are true. bool IsEnforced() const; + // Returns true if the search branching strategy is FIXED_SEARCH and if the + // SAT solver has not made any backtrack yet. In this case the scheduling + // search heuristics is strong enough to allow skipping some scheduling + // propagators. + bool FixedSearchFirstSolutionMode() const; + // This is a propagator so we can "cache" all the intervals relevant // information. This gives good speedup. Note however that the info is stale // except if a bound was pushed by this helper or if this was called. We run @@ -132,6 +138,7 @@ class SchedulingConstraintHelper : public PropagatorInterface { // This is used by NoOverlap2DConstraintHelper, which registers itself but // does not register its x and y SchedulingConstraintHelpers. void SetEnforcementId(EnforcementId id) { enforcement_id_ = id; } + void RecomputeCache(int t) { recompute_cache_.Set(t); } // Resets the class to the same state as if it was constructed with // the given subset of tasks from other (and the same enforcement literals). @@ -477,6 +484,7 @@ class SchedulingConstraintHelper : public PropagatorInterface { RootLevelLinear2Bounds* root_level_lin2_bounds_; EnforcementHelper& enforcement_helper_; EnforcementId enforcement_id_; + bool fixed_search_; FixedCapacityVector scratch_task_time_vector1_; FixedCapacityVector scratch_task_time_vector2_; diff --git a/ortools/sat/scheduling_local_search.cc b/ortools/sat/scheduling_local_search.cc new file mode 100644 index 00000000000..f629b4d90ba --- /dev/null +++ b/ortools/sat/scheduling_local_search.cc @@ -0,0 +1,1062 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/scheduling_local_search.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/log/vlog_is_on.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/distributions.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/base/types.h" +#include "ortools/graph_base/topologicalsorter.h" +#include "ortools/sat/combine_solutions.h" +#include "ortools/sat/cp_model_checker.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/scheduling_model.h" +#include "ortools/sat/stat_tables.h" +#include "ortools/sat/subsolver.h" +#include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" +#include "ortools/util/random_engine.h" +#include "ortools/util/time_limit.h" + +namespace operations_research { +namespace sat { + +CompactVectorVector SchedulingLocalSearch::BuildInitialMachineSequences( + absl::Span initial_solution, + absl::Span current_machines) const { + CompactVectorVectorBuilder machine_tasks_builder; + machine_tasks_builder.ReserveNumItems(num_tasks_); + for (int i = 0; i < num_tasks_; ++i) { + machine_tasks_builder.Add(current_machines[i], i); + } + CompactVectorVector machine_tasks; + machine_tasks.ResetFromBuilder(machine_tasks_builder, num_machines_); + + for (int machine = 0; machine < num_machines_; ++machine) { + absl::Span tasks_on_machine = machine_tasks[machine]; + absl::c_sort(tasks_on_machine, [&initial_solution](int a, int b) { + return initial_solution[a] < initial_solution[b]; + }); + } + + return machine_tasks; +} + +std::vector +SchedulingLocalSearch::GenerateN8Moves( + absl::Span critical_path, absl::Span prev_on_machine, + absl::Span next_on_machine, + absl::Span start_mins, + absl::Span tails, + absl::Span current_machines, + absl::Span current_durations) const { + if (critical_path.size() < 2) return {}; + + // Heuristics to forbid moves that can potentially create cycles. + auto is_invalid_forward_move = [&](int x, int target) { + if (target == -1) return false; + + // This check is not needed for classical JSSP where there is no precedences + // between operations on the same machine. + if (job_reachability_[x * num_tasks_ + target]) return true; + + for (const int js : job_successors_[x]) { + // Another check that is not needed for classical JSSP where there are no + // precedences between operations on the same machine. If a static + // successor is on the same machine, we cannot leap over it. If it starts + // before or at the target, we are leaping over it (or landing on it), + // which creates a cycle. + if (current_machines[js] == current_machines[x]) { + if (start_mins[js] <= start_mins[target]) return true; + } + + // Xie et al. Prop 1: moving x here provably worsens the makespan or + // forms an indirect cycle. + if (current_durations[target] + tails[target] < tails[js]) return true; + } + return false; + }; + + auto is_invalid_backward_move = [&](int x, int target) { + if (target == -1) return false; + + if (job_reachability_[target * num_tasks_ + x]) return true; + + for (const int jp : + problem_.tasks[x].tasks_that_must_complete_before_this) { + if (current_machines[jp] == current_machines[x]) { + if (start_mins[jp] >= start_mins[target]) return true; + } + + if (start_mins[target] + current_durations[target] < + start_mins[jp] + current_durations[jp]) { + return true; + } + } + return false; + }; + + struct Block { + // Indexes are relative to the position on the critical_path vector. + int start_idx; + int end_idx; + }; + + // A block is a sequence of tasks in the critical path that must be executed + // in sequence on the same machine. + std::vector blocks; + blocks.reserve(critical_path.size()); + + // Fast lookup table to prevent external leaps from inserting tasks back into + // other sections of the critical path. + std::vector is_critical(num_tasks_, false); + + int current_start = 0; + for (int i = 0; i <= critical_path.size(); ++i) { + if (i != critical_path.size()) { + is_critical[critical_path[i]] = true; + } + if (i == critical_path.size() || + current_machines[critical_path[i]] != + current_machines[critical_path[current_start]]) { + if (i - current_start >= 1) { + blocks.push_back({current_start, i - 1}); + } + current_start = i; + } + } + + const int num_blocks = blocks.size(); + if (num_blocks == 0) return {}; + + std::vector moves; + moves.reserve(critical_path.size() * 4); + + // First op restrictions and leftward leaps. + // Applies to all blocks except the first one. + for (int b = 1; b < num_blocks; ++b) { + const int start = blocks[b].start_idx; + const int end = blocks[b].end_idx; + const int first_op = critical_path[start]; + + // Prop 5: Inner ops moving left before first_op. + for (int i = start + 1; i <= end; ++i) { + if (!is_invalid_backward_move(critical_path[i], first_op)) { + moves.push_back({critical_path[i], first_op, false}); + } + } + + // Prop 3: first_op moving right after inner ops. + for (int i = start + 2; i < end; ++i) { + const int target = critical_path[i]; + if (is_invalid_forward_move(first_op, target)) break; // Monotonic prune. + moves.push_back({first_op, target, true}); + } + + // N8 external leaps (leftward). + int target = prev_on_machine[first_op]; + while (target != -1 && !is_critical[target]) { + bool pushed_any = false; + for (int i = start; i <= end; ++i) { + if (!is_invalid_backward_move(critical_path[i], target)) { + moves.push_back({critical_path[i], target, false}); + pushed_any = true; + } + } + if (!pushed_any) break; // Monotonic prune. + target = prev_on_machine[target]; + } + } + + // Last op restrictions and rightward leaps. + // Applies to all blocks except the last one. + for (int b = 0; b < num_blocks - 1; ++b) { + const int start = blocks[b].start_idx; + const int end = blocks[b].end_idx; + const int last_op = critical_path[end]; + + // Prop 6: Inner ops moving RIGHT after last_op. + for (int i = start; i < end; ++i) { + if (!is_invalid_forward_move(critical_path[i], last_op)) { + moves.push_back({critical_path[i], last_op, true}); + } + } + + // Prop 4: last_op moving LEFT before inner ops. + for (int i = end - 2; i > start; --i) { + const int target = critical_path[i]; + if (is_invalid_backward_move(last_op, target)) break; // Monotonic prune. + moves.push_back({last_op, target, false}); + } + + // N8 external leaps (rightward). + int target = next_on_machine[last_op]; + while (target != -1 && !is_critical[target]) { + bool pushed_any = false; + // Evaluate ALL operations in the block, including last_op! + for (int i = start; i <= end; ++i) { + if (!is_invalid_forward_move(critical_path[i], target)) { + moves.push_back({critical_path[i], target, true}); + pushed_any = true; + } + } + if (!pushed_any) break; // Monotonic prune. + target = next_on_machine[target]; + } + } + + return moves; +} + +IntegerValue SchedulingLocalSearch::EstimateMakespanForInsert( + const InsertMove& move, absl::Span start_mins, + absl::Span tails, absl::Span prev_on_machine, + absl::Span next_on_machine, + absl::Span current_durations, + MoveEvaluationScratch* scratch) const { + const int x = move.task; + const int target = move.target_task; + DCHECK_NE(x, target); + + // Helper to get the earliest possible start_min for a task ignoring the + // current machine attribution for this task only. + auto get_job_head = [&](int task) { + IntegerValue j_head = problem_.tasks[task].min_start; + for (int p : problem_.tasks[task].tasks_that_must_complete_before_this) { + j_head = std::max(j_head, start_mins[p] + current_durations[p]); + } + return j_head; + }; + + // Same as above, but for the job tail. + auto get_job_tail = [&](int task) { + IntegerValue j_tail = 0; + for (int s : job_successors_[task]) { + j_tail = std::max(j_tail, current_durations[s] + tails[s]); + } + return j_tail; + }; + + // Build the contiguous mutated sequence of tasks on this machine. + // + // Note that in the original Balas & Vazacopoulos paper, the moved task 'u' + // and target 'v' (our 'x' and 'target') are picked so all tasks between + // them on the critical path are on the same machine. The article refers to + // the 'Q' set of tasks between 'u' and 'v' as a subsequence of the critical + // path. However, it is also technically correct to say that 'Q' is simply a + // subsequence of the machine tasks. + // + // Because we generalize this for the N8 leaps from Xie et al., our tasks + // are picked from the same machine, but the sequence between them is not + // necessarily on the critical path. `mutated_sequence` below represents + // this generalized displaced machine segment (Q) combined with the moved + // task (x). + std::vector& mutated_sequence = scratch->mutated_sequence; + mutated_sequence.clear(); + int p_m = -1; // The machine predecessor to the entire mutated segment. + int s_m = -1; // The machine successor to the entire mutated segment. + + if (move.place_after) { + p_m = prev_on_machine[x]; + s_m = next_on_machine[target] == x ? next_on_machine[x] + : next_on_machine[target]; + + int curr = next_on_machine[x]; + while (curr != -1) { + mutated_sequence.push_back(curr); + if (curr == target) break; + curr = next_on_machine[curr]; + } + mutated_sequence.push_back(x); + } else { + p_m = prev_on_machine[target] == x ? prev_on_machine[x] + : prev_on_machine[target]; + s_m = next_on_machine[x]; + + mutated_sequence.push_back(x); + const int stop_node = prev_on_machine[x]; + int curr = target; + while (curr != -1) { + mutated_sequence.push_back(curr); + if (curr == stop_node) break; + curr = next_on_machine[curr]; + } + } + + // Forward sweep to calculate new heads. + std::vector& new_heads = scratch->heads; + new_heads.resize(mutated_sequence.size()); + IntegerValue head = 0; + if (p_m != -1) { + head = start_mins[p_m] + current_durations[p_m]; + } + + for (size_t i = 0; i < mutated_sequence.size(); ++i) { + const int t = mutated_sequence[i]; + head = std::max(head, get_job_head(t)); + new_heads[i] = head; + head += current_durations[t]; + } + + // Backward sweep to calculate new tails. + std::vector& new_tails = scratch->tails; + new_tails.resize(mutated_sequence.size()); + IntegerValue tail = 0; + if (s_m != -1) { + tail = current_durations[s_m] + tails[s_m]; + } + + for (int i = mutated_sequence.size() - 1; i >= 0; --i) { + const int t = mutated_sequence[i]; + tail = std::max(tail, get_job_tail(t)); + new_tails[i] = tail; + tail += current_durations[t]; + } + + // Find the maximum path through the mutated segment. + IntegerValue max_path = 0; + for (size_t i = 0; i < mutated_sequence.size(); ++i) { + const IntegerValue path = + new_heads[i] + current_durations[mutated_sequence[i]] + new_tails[i]; + max_path = std::max(max_path, path); + } + + return max_path; +} + +std::optional +SchedulingLocalSearch::SelectBestMove( + absl::Span candidates, + absl::Span start_mins, + absl::Span tails, absl::Span prev_on_machine, + absl::Span next_on_machine, absl::Span tabu_matrix, + absl::Span current_durations, int current_iteration, + IntegerValue global_best_makespan, absl::BitGenRef random) const { + DCHECK_EQ(start_mins.size(), num_tasks_); + DCHECK_EQ(tails.size(), num_tasks_); + DCHECK_EQ(prev_on_machine.size(), num_tasks_); + DCHECK_EQ(next_on_machine.size(), num_tasks_); + + std::optional best_move = std::nullopt; + IntegerValue best_estimate = kMaxIntegerValue; + int tie_count = 0; + + std::optional best_tabu_move = std::nullopt; + IntegerValue best_tabu_estimate = kMaxIntegerValue; + int tabu_tie_count = 0; + + MoveEvaluationScratch scratch; + for (const InsertMove& move : candidates) { + const int x = move.task; + const int target = move.target_task; + + bool is_tabu = false; + { + // Tabu check + int new_prev = move.place_after ? target : prev_on_machine[target]; + if (new_prev == x) new_prev = prev_on_machine[x]; + + int new_next = move.place_after ? next_on_machine[target] : target; + if (new_next == x) new_next = next_on_machine[x]; + + if (new_prev != -1) { + int u = x; + int v = new_prev; + if (u > v) std::swap(u, v); + DCHECK_LT(u * num_tasks_ + v, tabu_matrix.size()); + if (tabu_matrix[u * num_tasks_ + v] > current_iteration) is_tabu = true; + } + if (!is_tabu && new_next != -1) { + int u = x; + int v = new_next; + if (u > v) std::swap(u, v); + DCHECK_LT(u * num_tasks_ + v, tabu_matrix.size()); + if (tabu_matrix[u * num_tasks_ + v] > current_iteration) is_tabu = true; + } + } + + // Evaluate an approximation of the makespan if we make this move. + const IntegerValue estimate = + EstimateMakespanForInsert(move, start_mins, tails, prev_on_machine, + next_on_machine, current_durations, &scratch); + + // 3. Tabu Search Logic + Aspiration Criterion + // If it's NOT tabu, it's a valid candidate. + // If it IS tabu, we only allow it if the estimate beats our all-time global + // best schedule. + if (!is_tabu || estimate < global_best_makespan) { + if (estimate < best_estimate) { + // We found a strictly better move! + best_estimate = estimate; + best_move = move; + tie_count = 1; + } else if (estimate == best_estimate) { + // Use Reservoir Sampling for tie-breaking. + tie_count++; + // There is exactly a 1-in-N chance of replacing the current best. + if (absl::Uniform(random, 0, tie_count) == 0) { + best_move = move; + } + } + } else { + // Track the best forbidden move just in case we get trapped. + if (estimate < best_tabu_estimate) { + best_tabu_estimate = estimate; + best_tabu_move = move; + tabu_tie_count = 1; + } else if (estimate == best_tabu_estimate) { + tabu_tie_count++; + if (absl::Uniform(random, 0, tabu_tie_count) == 0) { + best_tabu_move = move; + } + } + } + } + + // If we are completely boxed in by the Tabu list, use the best Tabu move. + if (!best_move.has_value() && best_tabu_move.has_value()) { + return best_tabu_move; + } + + return best_move; +} + +CompactVectorVector SchedulingLocalSearch::ComputeJobSuccessors( + const SchedulingProblem& problem) { + const int num_tasks = problem.tasks.size(); + CompactVectorVectorBuilder job_successors_builder; + job_successors_builder.ReserveNumItems(num_tasks); + + for (int v = 0; v < num_tasks; ++v) { + for (int u : problem.tasks[v].tasks_that_must_complete_before_this) { + job_successors_builder.Add(u, v); + } + } + return CompactVectorVector(job_successors_builder, num_tasks); +} + +SchedulingLocalSearch::SolverState SchedulingLocalSearch::ComputeDynamicState( + const CompactVectorVector& machine_tasks, + absl::Span topo_order, + absl::Span current_durations) const { + SolverState state; + state.prev_on_machine.assign(num_tasks_, -1); + state.next_on_machine.assign(num_tasks_, -1); + state.position_in_machine.assign(num_tasks_, -1); + state.tails.assign(num_tasks_, 0); + + // 1. Build O(1) Machine Successor/Predecessor lookups + for (int m = 0; m < machine_tasks.size(); ++m) { + const absl::Span tasks_on_machine = machine_tasks[m]; + for (int i = 0; i < tasks_on_machine.size(); ++i) { + const int u = tasks_on_machine[i]; + state.position_in_machine[u] = i; + if (i > 0) { + state.prev_on_machine[u] = tasks_on_machine[i - 1]; + } + if (i < tasks_on_machine.size() - 1) { + state.next_on_machine[u] = tasks_on_machine[i + 1]; + } + } + } + + // 2. Calculate Tails (Reverse Topological Traversal) + // We iterate backwards through the topological order. By the time we visit + // a node, all of its successors have already finalized their tails. + for (int i = num_tasks_ - 1; i >= 0; --i) { + const int u = topo_order[i]; + IntegerValue max_tail = 0; + + // A. Check Job Successors + for (const int s : job_successors_[u]) { + const IntegerValue path_len = current_durations[s] + state.tails[s]; + if (path_len > max_tail) { + max_tail = path_len; + } + } + + // B. Check Machine Successor + const int m_succ = state.next_on_machine[u]; + if (m_succ != -1) { + const IntegerValue path_len = + current_durations[m_succ] + state.tails[m_succ]; + if (path_len > max_tail) { + max_tail = path_len; + } + } + + state.tails[u] = max_tail; + } + + return state; +} + +void SchedulingLocalSearch::AnalyzeSchedule( + const CompactVectorVector& machine_tasks, + absl::Span current_durations, + SchedulingLocalSearch::ScheduleAnalysis* analysis) const { + // Build the precedences graph, taking into account the order of the tasks on + // the machines and the task precedences. + std::vector prev_on_machine(num_tasks_, -1); + CompactVectorVector& adj = analysis->scratch_adjacency_list; + + for (int machine = 0; machine < num_machines_; ++machine) { + const auto& tasks_on_machine = machine_tasks[machine]; + for (int i = 1; i < tasks_on_machine.size(); ++i) { + int u = tasks_on_machine[i - 1]; + int v = tasks_on_machine[i]; + prev_on_machine[v] = u; + adj[u][adj[u].size() - 1] = v; + } + if (!tasks_on_machine.empty()) { + // Make sure we overwrite the fake next task for every task. + absl::Span adj_last = adj[tasks_on_machine.back()]; + adj_last[adj_last.size() - 1] = num_tasks_; + } + } + + // Now run a topological sort and compute the earliest start time of each + // task by forcing it to be equal the maximal completion time of all its + // predecessors. + absl::StatusOr> maybe_topo_order = + util::graph::FastTopologicalSort(adj); + CHECK_OK(maybe_topo_order); + analysis->topo_order = std::move(*maybe_topo_order); + CHECK_EQ(analysis->topo_order.size(), num_tasks_ + 1); + CHECK_EQ(analysis->topo_order.back(), num_tasks_); + analysis->topo_order.pop_back(); + + // Find the start_min + analysis->start_mins.assign(num_tasks_, 0); + for (int i = 0; i < num_tasks_; ++i) { + analysis->start_mins[i] = problem_.tasks[i].min_start; + } + + for (int u : analysis->topo_order) { + if (u == num_tasks_) continue; + const IntegerValue completion_time_u = + analysis->start_mins[u] + current_durations[u]; + for (int v : adj[u]) { + if (v == num_tasks_) continue; + if (completion_time_u > analysis->start_mins[v]) { + analysis->start_mins[v] = completion_time_u; + } + } + } + + // Use our calculated EST to find the makespan and the last task. + int last_task = -1; + analysis->makespan = -1; + for (int i = 0; i < num_tasks_; ++i) { + const IntegerValue comp_time = + analysis->start_mins[i] + current_durations[i]; + if (comp_time > analysis->makespan) { + analysis->makespan = comp_time; + last_task = i; + } + } + + // Start from the last task and look for the reason it starts at the current + // EST, which must be a predecessor that ends at the current task EST. + int curr = last_task; + analysis->critical_path.clear(); + analysis->critical_path.reserve(num_tasks_ / num_machines_); + + while (curr != -1) { + analysis->critical_path.push_back(curr); + if (analysis->start_mins[curr] == problem_.tasks[curr].min_start) break; + + int tightest_pred = -1; + for (const int pred : + problem_.tasks[curr].tasks_that_must_complete_before_this) { + if (analysis->start_mins[pred] + current_durations[pred] == + analysis->start_mins[curr]) { + tightest_pred = pred; + break; + } + } + + // If no job predecessor is the bottleneck, check the machine predecessor + if (tightest_pred == -1) { + int m_pred = prev_on_machine[curr]; + if (m_pred != -1 && + analysis->start_mins[m_pred] + current_durations[m_pred] == + analysis->start_mins[curr]) { + tightest_pred = m_pred; + } + } + curr = tightest_pred; + } + + absl::c_reverse(analysis->critical_path); +} + +std::vector SchedulingLocalSearch::ComputeJobReachability( + const SchedulingProblem& problem, + const CompactVectorVector& job_successors) { + const int num_tasks = problem.tasks.size(); + // Compute static job reachability to prevent cyclic swaps + std::vector job_reachability(num_tasks * num_tasks, false); + std::vector stack; + stack.reserve(num_tasks); + for (int i = 0; i < num_tasks; ++i) { + stack.assign({i}); + while (!stack.empty()) { + int curr = stack.back(); + stack.pop_back(); + for (int succ : job_successors[curr]) { + if (!job_reachability[i * num_tasks + succ]) { + job_reachability[i * num_tasks + succ] = true; + stack.push_back(succ); + } + } + } + } + return job_reachability; +} + +SchedulingLocalSearch::SchedulingLocalSearch(const SchedulingProblem& problem) + : problem_(problem), + precedences_(ComputePrecedences(problem)), + job_successors_(ComputeJobSuccessors(problem)), + job_reachability_(ComputeJobReachability(problem, job_successors_)), + num_tasks_(problem_.tasks.size()), + num_machines_(([](const SchedulingProblem& p) { + int max_m = -1; + for (const auto& t : p.tasks) { + for (int m : t.compatible_machine) { + max_m = std::max(max_m, m); + } + } + return max_m + 1; + })(problem_)) { + CHECK(!problem_.tasks.empty()); + CHECK_GT(num_machines_, 0); +} + +SchedulingLocalSearch::ScheduleAnalysis::ScheduleAnalysis( + const SchedulingProblem& problem) { + // Build the scratch adjacency list for topological sort. + std::vector> adj_edges; + for (int v = 0; v < problem.tasks.size(); ++v) { + for (int u : problem.tasks[v].tasks_that_must_complete_before_this) { + adj_edges.push_back({u, v}); + } + } + for (int i = 0; i < problem.tasks.size(); ++i) { + // Overwritable fake task with index num_tasks_. + adj_edges.push_back({i, problem.tasks.size()}); + } + scratch_adjacency_list.ResetFromPairs(adj_edges, problem.tasks.size() + 1); +} + +// static +std::vector> SchedulingLocalSearch::ComputePrecedences( + const SchedulingProblem& problem) { + std::vector> precedences; + precedences.reserve(problem.tasks.size()); + for (int v = 0; v < problem.tasks.size(); ++v) { + for (int u : problem.tasks[v].tasks_that_must_complete_before_this) { + precedences.push_back({u, v}); + } + } + return precedences; +} + +std::vector SchedulingLocalSearch::Solve( + absl::Span initial_hint, + absl::Span active_machine_indices, int64_t makespan_to_beat, + absl::BitGenRef random, TimeLimit* time_limit) const { + if (num_tasks_ > 40000) { + // TODO(user): Support huge problems. + return {}; + } + + std::vector current_machines(num_tasks_); + std::vector current_durations(num_tasks_); + for (int i = 0; i < num_tasks_; ++i) { + const int a_idx = active_machine_indices[i]; + current_machines[i] = problem_.tasks[i].compatible_machine[a_idx]; + current_durations[i] = problem_.tasks[i].duration_for_machine[a_idx]; + } + + std::vector tabu_matrix(num_tasks_ * num_tasks_, 0); + + // Build the initial sequence state from the external hint + CompactVectorVector current_machine_tasks = + BuildInitialMachineSequences(initial_hint, current_machines); + + IntegerValue initial_makespan = -1; + + CompactVectorVector best_machine_tasks; + std::vector best_solution_start_mins(num_tasks_); + IntegerValue global_best_makespan = kMaxIntegerValue; + + TimeLimitCheckEveryNCalls time_limit_check(100, time_limit); + int current_iteration = 0; + const int min_iterations = 10000; + const int max_iterations = 1000000; + const int num_jobs = num_tasks_ / num_machines_; + const int baseline_tenure = 10 + num_jobs / num_machines_; + const int randomized_tenure = absl::Uniform( + random, std::max(5, static_cast(baseline_tenure * 0.5)), + static_cast(baseline_tenure * 1.5) + 1); + const int tabu_tenure_min = randomized_tenure; + const int tabu_tenure_max = randomized_tenure + randomized_tenure / 2; + const int stagnation_limit = absl::Uniform(random, 100, 400); + + ScheduleAnalysis analysis(problem_); + + int iterations_without_improvement = 0; + for (current_iteration = 0; current_iteration < max_iterations; + ++current_iteration) { + if (time_limit_check.LimitReached()) break; + + // Analyze the current sequence. + AnalyzeSchedule(current_machine_tasks, current_durations, &analysis); + + // Track our absolute best schedule + if (analysis.makespan < global_best_makespan) { + if (current_iteration == 0) { + // Keep the initial makespan for debugging and logging. + initial_makespan = analysis.makespan; + } + global_best_makespan = analysis.makespan; + best_machine_tasks = current_machine_tasks; + for (int i = 0; i < num_tasks_; ++i) { + best_solution_start_mins[i] = analysis.start_mins[i].value(); + } + iterations_without_improvement = 0; + } + ++iterations_without_improvement; + + if (analysis.makespan < makespan_to_beat && + current_iteration > min_iterations) { + break; + } + + // Compute dynamic state. + const SolverState state = ComputeDynamicState( + current_machine_tasks, analysis.topo_order, current_durations); + + const std::vector candidates = GenerateN8Moves( + analysis.critical_path, state.prev_on_machine, state.next_on_machine, + analysis.start_mins, state.tails, current_machines, current_durations); + + // Select the best move. + std::optional best_move = SelectBestMove( + candidates, analysis.start_mins, state.tails, state.prev_on_machine, + state.next_on_machine, tabu_matrix, current_durations, + current_iteration, global_best_makespan, random); + + if (iterations_without_improvement > stagnation_limit) { + // To diversify the our tabu search, when the search stagnates, we + // select a random move from the candidates. The idea is older, but the + // specific implementation is based on "Xie, J., et al., A new + // neighborhood structure for job shop scheduling problems. International + // Journal of Production Research, 1–15 (2022)". + if (!candidates.empty()) { + iterations_without_improvement = 0; + best_move = + candidates[absl::Uniform(random, 0, candidates.size())]; + } + } + + if (!best_move.has_value()) break; + + // Apply the Move directly to the schedule sequence + const int x = best_move->task; + const int target = best_move->target_task; + const int m = current_machines[x]; + + const int idx_x = state.position_in_machine[x]; + const int idx_target = state.position_in_machine[target]; + + // Determine the raw destination index based on placement + int insert_idx = idx_target + (best_move->place_after ? 1 : 0); + + // If the task is moving right, its own removal shifts the target index left + if (idx_x < insert_idx) { + insert_idx--; + } + + absl::Span machine_seq = current_machine_tasks[m]; + + // Execute the insertion in-place via std::rotate + if (idx_x < insert_idx) { + // Moving task to the right. + // Everything between idx_x+1 and insert_idx shifts left by 1. + std::rotate(machine_seq.begin() + idx_x, machine_seq.begin() + idx_x + 1, + machine_seq.begin() + insert_idx + 1); + } else if (idx_x > insert_idx) { + // Moving task to the left. + // Everything between insert_idx and idx_x-1 shifts right by 1. + std::rotate(machine_seq.begin() + insert_idx, machine_seq.begin() + idx_x, + machine_seq.begin() + idx_x + 1); + } + + // Update Tabu matrix using a random tenure length + const int tenure = + absl::Uniform(random, tabu_tenure_min, tabu_tenure_max + 1); + const int old_prev = state.prev_on_machine[x]; + const int old_next = state.next_on_machine[x]; + + // Penalize the task jumping back next to its old neighbors + if (old_prev != -1) { + tabu_matrix[x * num_tasks_ + old_prev] = current_iteration + tenure; + tabu_matrix[old_prev * num_tasks_ + x] = current_iteration + tenure; + } + if (old_next != -1) { + tabu_matrix[x * num_tasks_ + old_next] = current_iteration + tenure; + tabu_matrix[old_next * num_tasks_ + x] = current_iteration + tenure; + } + } + if (VLOG_IS_ON(2)) { + AnalyzeSchedule(best_machine_tasks, current_durations, &analysis); + } + VLOG(2) << "Initial makespan: " << initial_makespan + << " best makespan: " << analysis.makespan + << " num_iterations: " << current_iteration; + + // Return the start times of the best schedule we found. + return best_solution_start_mins; +} + +namespace { + +std::vector ComputeCpSatSolutionFromSchedulingSolution( + const SchedulingProblemAndMapping& problem_and_mapping, + const CpModelProto& input_model_proto, + absl::Span scheduling_solution, + absl::Span active_machine_indices, + absl::Span original_solution) { + // 1. Copy the entire original solution to inherit non-scheduling variables + std::vector new_solution(original_solution.begin(), + original_solution.end()); + + int64_t new_makespan = 0; + + auto assign_expr = [&](const LinearExpressionProto& expr, + int64_t target_val) { + // The division might not be exact: our definition of SchedulingRelaxation + // cannot enforce the start times to be a multiple of some coefficient. We + // might end up with a infeasible solution. + if (expr.vars().size() == 1) { + const int var = expr.vars(0); + const int64_t coeff = expr.coeffs(0); + const int64_t offset = expr.offset(); + + DCHECK_GE(var, 0); + new_solution[var] = (target_val - offset) / coeff; + } + }; + + for (int i = 0; i < problem_and_mapping.problem.tasks.size(); ++i) { + const int64_t new_start = scheduling_solution[i]; + const int active_machine_idx = active_machine_indices[i]; + + // 2. Explicitly set all intervals' optional presence booleans + // based on the dynamically provided machine choices. + const SchedulingProblem::Task& task = problem_and_mapping.problem.tasks[i]; + for (int a = 0; a < task.compatible_machine.size(); ++a) { + const int lit = problem_and_mapping.task_to_presence_literals[i][a]; + if (lit == kint32max) continue; + + const bool should_be_true = (a == active_machine_idx); + if (RefIsPositive(lit)) { + new_solution[lit] = should_be_true ? 1 : 0; + } else { + new_solution[NegatedRef(lit)] = should_be_true ? 0 : 1; + } + } + + const int64_t duration = problem_and_mapping.problem.tasks[i] + .duration_for_machine[active_machine_idx]; + const int64_t new_end = new_start + duration; + new_makespan = std::max(new_makespan, new_end); + + // 3. Set interval start, end, and size variables + const auto& task_intervals = problem_and_mapping.task_to_intervals[i]; + + // 3a. Repair the specific active alternative interval + if (active_machine_idx < task_intervals.alternative_intervals.size()) { + const int active_interval_idx = + task_intervals.alternative_intervals[active_machine_idx]; + const auto& interval = + input_model_proto.constraints(active_interval_idx).interval(); + + assign_expr(interval.start(), new_start); + assign_expr(interval.end(), new_end); + assign_expr(interval.size(), duration); + } + + // 3b. Repair all main/shared intervals + for (const int main_idx : task_intervals.unconditional_intervals) { + const auto& interval = input_model_proto.constraints(main_idx).interval(); + + assign_expr(interval.start(), new_start); + assign_expr(interval.end(), new_end); + assign_expr(interval.size(), duration); + } + } + + // 4. Update Makespan if present + if (problem_and_mapping.makespan_expr.has_value()) { + const auto& [makespan_var, coeff, offset] = + problem_and_mapping.makespan_expr.value(); + + DCHECK_GE(makespan_var, 0); + new_solution[makespan_var] = (new_makespan - offset) / coeff; + } + + return new_solution; +} + +} // namespace + +SchedulingLocalSearchSolver::SchedulingLocalSearchSolver( + const absl::string_view name, SubSolver::SubsolverType type, + const CpModelProto& input_model_proto, SatParameters params, + ModelSharedTimeLimit* shared_time_limit, + SharedResponseManager* shared_response, SharedStatTables* stat_tables) + : SubSolver(name, type), + input_model_proto_(input_model_proto), + params_(params), + shared_time_limit_(shared_time_limit), + shared_response_(shared_response), + stat_tables_(stat_tables) { + relaxation_ = DetectSchedulingProblems(input_model_proto_); + for (const auto& problem_and_mapping : relaxation_.problems_and_mappings) { + local_search_solvers_.emplace_back( + std::make_unique(problem_and_mapping.problem)); + } +} + +std::function SchedulingLocalSearchSolver::GenerateTask( + int64_t task_id) { + return [this, task_id]() { + if (relaxation_.problems_and_mappings.empty()) return; + TimeLimit task_time_limit; + shared_time_limit_->UpdateLocalLimit(&task_time_limit); + // Create a random number generator whose seed depends both on the task_id + // and on the params_.random_seed() so that changing the later will + // change the LNS behavior. + const int32_t low = static_cast(task_id); + const int32_t high = static_cast(task_id >> 32); + std::seed_seq seed{low, high, params_.random_seed()}; + random_engine_t random(seed); + const int problem_index = + absl::Uniform(random, 0, relaxation_.problems_and_mappings.size()); + if (relaxation_.problems_and_mappings[problem_index].problem.tasks.size() < + 3) + return; + const SchedulingProblemAndMapping& problem_and_mapping = + relaxation_.problems_and_mappings[problem_index]; + const auto base_solution = + shared_response_->SolutionPool().GetSolutionToImprove(random); + if (base_solution == nullptr) return; + + std::vector scheduling_solution; + scheduling_solution.reserve( + problem_and_mapping.task_to_start_time_model_var.size()); + for (int i = 0; i < problem_and_mapping.task_to_start_time_model_var.size(); + ++i) { + const auto& [var, coeff, offset] = + problem_and_mapping.task_to_start_time_model_var[i]; + scheduling_solution.push_back( + base_solution->variable_values[var] * coeff + offset); + } + + std::vector active_machine_indices( + problem_and_mapping.problem.tasks.size(), 0); + for (int i = 0; i < problem_and_mapping.problem.tasks.size(); ++i) { + const SchedulingProblem::Task& task = + problem_and_mapping.problem.tasks[i]; + int active_idx = 0; + for (int a = 0; a < task.compatible_machine.size(); ++a) { + int lit = problem_and_mapping.task_to_presence_literals[i][a]; + if (lit == kint32max) { + // A single machine choice, no enforcement needed. + active_idx = a; + break; + } + const bool is_true = + (RefIsPositive(lit)) + ? (base_solution->variable_values[lit] == 1) + : (base_solution->variable_values[NegatedRef(lit)] == 0); + if (is_true) { + active_idx = a; + break; + } + } + active_machine_indices[i] = active_idx; + } + + int64_t relaxed_objective_value = 0; + + DCHECK(VerifySchedulingRelaxation( + relaxation_, base_solution->variable_values, &relaxed_objective_value)); + DCHECK_LE(relaxed_objective_value, + ComputeInnerObjective(input_model_proto_.objective(), + base_solution->variable_values)); + SchedulingLocalSearch& local_search_solver = + *local_search_solvers_[problem_index]; + + const std::vector new_scheduling_solution = + local_search_solver.Solve( + absl::MakeConstSpan(scheduling_solution), + absl::MakeConstSpan(active_machine_indices), + shared_response_->BestSolutionInnerObjectiveValue().value(), random, + &task_time_limit); + if (new_scheduling_solution.empty()) return; + + std::vector new_solution = + ComputeCpSatSolutionFromSchedulingSolution( + problem_and_mapping, input_model_proto_, new_scheduling_solution, + active_machine_indices, base_solution->variable_values); + + if (SolutionIsFeasible(input_model_proto_, new_solution)) { + const int64_t new_objective_value = + ComputeInnerObjective(input_model_proto_.objective(), new_solution); + VLOG(2) << "New solution is feasible, objective value: " + << new_objective_value << " best: " + << shared_response_->BestSolutionInnerObjectiveValue(); + PushAndMaybeCombineSolution(shared_response_, input_model_proto_, + new_solution, this->name(), base_solution); + } else { + // TODO(user): try to fix it with ViolationLS. + VLOG(2) << "New solution is infeasible"; + } + }; +} + +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/scheduling_local_search.h b/ortools/sat/scheduling_local_search.h new file mode 100644 index 00000000000..efcaaa96d3d --- /dev/null +++ b/ortools/sat/scheduling_local_search.h @@ -0,0 +1,229 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ORTOOLS_SAT_SCHEDULING_LOCAL_SEARCH_H_ +#define ORTOOLS_SAT_SCHEDULING_LOCAL_SEARCH_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/scheduling_model.h" +#include "ortools/sat/stat_tables.h" +#include "ortools/sat/subsolver.h" +#include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" +#include "ortools/util/time_limit.h" + +namespace operations_research { +namespace sat { + +// Stand-alone class for improving solutions to scheduling problems using local +// search. The heuristic this class uses is the Nowicki & Smutnicki algorithm +// described in "Eugeniusz Nowicki, Czeslaw Smutnicki, (1996) A Fast Taboo +// Search Algorithm for the Job Shop Problem. Management Science 42(6):797-813". +class SchedulingLocalSearch { + public: + explicit SchedulingLocalSearch(const SchedulingProblem& problem); + + SchedulingLocalSearch(const SchedulingLocalSearch&) = delete; + SchedulingLocalSearch& operator=(const SchedulingLocalSearch&) = delete; + + // Run local search on the given problem and return the best solution found. + // Will stop early if the makespan is less than makespan_to_beat. Thread-safe. + std::vector Solve(absl::Span initial_hint, + absl::Span active_machine_indices, + int64_t makespan_to_beat, absl::BitGenRef random, + TimeLimit* time_limit) const; + + protected: // Exposed for testing. + struct ScheduleAnalysis { + explicit ScheduleAnalysis(const SchedulingProblem& problem); + + std::vector critical_path; + std::vector start_mins; + std::vector topo_order; + IntegerValue makespan; + + // This is used to compute the topological order on the graph of tasks + // including the precedences and the order of tasks on machines. + // This gets initialized so that scratch_adjacency_list[i] contains the + // job successors of task i and a "fake" task with index num_tasks_. To + // get a topological order of the tasks including the order of tasks on + // machines, we need to overwrite this fake task with the actual tasks that + // follow task i on the same machine, if any. + CompactVectorVector scratch_adjacency_list; + }; + + struct InsertMove { + int task; // Task to be moved + int target_task; // Reference task for the new position + bool place_after; // true: insert after target_task; false: insert before + // target_task + }; + + struct SolverState { + std::vector prev_on_machine; + std::vector next_on_machine; + std::vector position_in_machine; + // Minimum time to complete all successors of a task. + std::vector tails; + }; + + static CompactVectorVector ComputeJobSuccessors( + const SchedulingProblem& problem); + + static std::vector ComputeJobReachability( + const SchedulingProblem& problem, + const CompactVectorVector& job_successors); + + static std::vector> ComputePrecedences( + const SchedulingProblem& problem); + + SolverState ComputeDynamicState( + const CompactVectorVector& machine_tasks, + absl::Span topo_order, + absl::Span current_durations) const; + + // Builds the sequence of tasks on each machine from the start times of the + // tasks. We want to use the task ordering as the internal solution + // representation to better encode our moves as swaps. + CompactVectorVector BuildInitialMachineSequences( + absl::Span initial_solution, + absl::Span current_machines) const; + + std::optional SelectBestMove( + absl::Span candidates, + absl::Span start_mins, + absl::Span tails, + absl::Span prev_on_machine, + absl::Span next_on_machine, absl::Span tabu_matrix, + absl::Span current_durations, int current_iteration, + IntegerValue global_best_makespan, absl::BitGenRef random) const; + + // Generates the possible moves for the local search following [1]. The moves + // are generated by defining a "block", which is a contiguous set of tasks + // that are on the critical path and on the same machine. To reduce the + // makespan, the critical path must be broken by reordering tasks in these + // blocks. + // + // The N8 neighborhood generates two distinct types of intra-machine moves: + // 1. Internal block moves: moving tasks to the boundaries of their own block + // (e.g., moving an inner task before the first task, or after the last + // task). + // 2. External leaps: taking tasks from the critical block and pushing them + // into non-critical time intervals on the same machine. + // + // [1] Xie, J., Li, X., Gao, L., & Gui, L. (2023). A new neighbourhood + // structure for job shop scheduling problems. International Journal of + // Production Research, 61(7), 2147-2161. + std::vector GenerateN8Moves( + absl::Span critical_path, + absl::Span prev_on_machine, + absl::Span next_on_machine, + absl::Span start_mins, + absl::Span tails, + absl::Span current_machines, + absl::Span current_durations) const; + + struct MoveEvaluationScratch { + std::vector mutated_sequence; + std::vector heads; + std::vector tails; + }; + + // Estimates the makespan of the solution if we apply the move `move`. + // Uses the estimator described in Section 3 of "Balas, E., & Vazacopoulos, A. + // (1998). Guided local search with shifting bottleneck for job shop + // scheduling. Management science, 44(2), 262-275". + IntegerValue EstimateMakespanForInsert( + const InsertMove& move, absl::Span start_mins, + absl::Span tails, + absl::Span prev_on_machine, + absl::Span next_on_machine, + absl::Span current_durations, + MoveEvaluationScratch* scratch) const; + + // Find the critical path of a solution using a topological sort. Also + // computes the implied heads (Earliest Start Times) of each task. + void AnalyzeSchedule(const CompactVectorVector& machine_tasks, + absl::Span current_durations, + SchedulingLocalSearch::ScheduleAnalysis* analysis) const; + + private: + const SchedulingProblem& problem_; + const std::vector> precedences_; + const CompactVectorVector job_successors_; + const std::vector job_reachability_; + const int num_tasks_; + const int num_machines_; +}; + +class SchedulingLocalSearchSolver : public SubSolver { + public: + SchedulingLocalSearchSolver(absl::string_view name, + SubSolver::SubsolverType type, + const CpModelProto& input_model_proto, + SatParameters params, + ModelSharedTimeLimit* shared_time_limit, + SharedResponseManager* shared_response, + SharedStatTables* stat_tables); + + SchedulingLocalSearchSolver(const SchedulingLocalSearchSolver&) = delete; + SchedulingLocalSearchSolver& operator=(const SchedulingLocalSearchSolver&) = + delete; + + ~SchedulingLocalSearchSolver() override { + stat_tables_->AddTimingStat(*this); + } + + std::function GenerateTask(int64_t /*task_id*/) final; + void Synchronize() final {} + bool IsDone() final { return false; } + bool TaskIsAvailable() final { + if (IsDone()) return false; + if (shared_response_->ProblemIsSolved()) return false; + if (shared_time_limit_->LimitReached()) return false; + if (relaxation_.problems_and_mappings.empty()) return false; + if (absl::c_all_of( + relaxation_.problems_and_mappings, + [](const SchedulingProblemAndMapping& problem_and_mapping) { + return problem_and_mapping.problem.tasks.size() < 3; + })) { + return false; + } + return shared_response_->HasFeasibleSolution(); + } + + private: + CpModelProto input_model_proto_; + SatParameters params_; + SharedTimeLimit* shared_time_limit_; + SharedResponseManager* shared_response_; + SharedStatTables* stat_tables_; + SchedulingRelaxation relaxation_; + std::vector> local_search_solvers_; +}; + +} // namespace sat +} // namespace operations_research + +#endif // ORTOOLS_SAT_SCHEDULING_LOCAL_SEARCH_H_ diff --git a/ortools/sat/scheduling_local_search_test.cc b/ortools/sat/scheduling_local_search_test.cc new file mode 100644 index 00000000000..1498039f161 --- /dev/null +++ b/ortools/sat/scheduling_local_search_test.cc @@ -0,0 +1,207 @@ +// Copyright 2010-2025 Google LLC +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ortools/sat/scheduling_local_search.h" + +#include +#include +#include + +#include "absl/types/span.h" +#include "gtest/gtest.h" +#include "ortools/base/gmock.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/scheduling_model.h" +#include "ortools/sat/util.h" + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::UnorderedElementsAre; + +namespace operations_research { +namespace sat { +namespace { + +class JSSPSolverTest : public ::testing::Test { + protected: + void SetUp() override { + // 1. Setup the Problem + problem_.type = SchedulingProblem::kMinimizeMakespan; + + // { {compatible_machines}, {durations}, min_start, {precedences} } + problem_.tasks = {// Task 0: M0, Dur 10 + {.compatible_machine = {0}, + .duration_for_machine = {10}, + .tasks_that_must_complete_before_this = {}, + .min_start = 0}, + // Task 1: M0, Dur 10 + {.compatible_machine = {0}, + .duration_for_machine = {10}, + .tasks_that_must_complete_before_this = {}, + .min_start = 0}, + // Task 2: M0, Dur 10 + {.compatible_machine = {0}, + .duration_for_machine = {10}, + .tasks_that_must_complete_before_this = {}, + .min_start = 0}, + // Task 3: M1, Dur 10. Depends on Task 2 + {.compatible_machine = {1}, + .duration_for_machine = {10}, + .tasks_that_must_complete_before_this = {2}, + .min_start = 0}}; + + // 2. Setup the current selected alternatives for the test + current_machines_ = {0, 0, 0, 1}; + current_durations_ = {10, 10, 10, 10}; + + // 3. Setup the Initial Sequence (M0: 0->1->2, M1: 3) + machine_tasks_.ResetFromPairs( + std::vector>({{0, 0}, {0, 1}, {0, 2}, {1, 3}})); + } + + struct SchedLSTest : public SchedulingLocalSearch { + public: + explicit SchedLSTest(const SchedulingProblem& problem) + : SchedulingLocalSearch(problem) {} + + using SchedulingLocalSearch::AnalyzeSchedule; + using SchedulingLocalSearch::BuildInitialMachineSequences; + using SchedulingLocalSearch::ComputeDynamicState; + using SchedulingLocalSearch::EstimateMakespanForInsert; + using SchedulingLocalSearch::GenerateN8Moves; + using SchedulingLocalSearch::InsertMove; + using SchedulingLocalSearch::MoveEvaluationScratch; + using SchedulingLocalSearch::ScheduleAnalysis; + using SchedulingLocalSearch::SolverState; + }; + + SchedulingProblem problem_; + CompactVectorVector machine_tasks_; + std::vector current_machines_; + std::vector current_durations_; +}; + +TEST_F(JSSPSolverTest, AnalyzeScheduleComputesCorrectGraph) { + const SchedLSTest solver(problem_); + SchedLSTest::ScheduleAnalysis analysis(problem_); + solver.AnalyzeSchedule(machine_tasks_, current_durations_, &analysis); + + // Check Makespan + EXPECT_EQ(analysis.makespan, 40); + + // Check Heads (Earliest Start Times) + EXPECT_THAT(analysis.start_mins, ElementsAre(0, 10, 20, 30)); + + // Check Critical Path + EXPECT_THAT(analysis.critical_path, ElementsAre(0, 1, 2, 3)); + + // Check Topo Order + EXPECT_THAT(analysis.topo_order[3], Eq(3)); + EXPECT_THAT(absl::MakeSpan(analysis.topo_order).subspan(0, 3), + UnorderedElementsAre(0, 1, 2)); +} + +TEST_F(JSSPSolverTest, ComputeDynamicStateBuildsCorrectTailsAndLookups) { + const SchedLSTest solver(problem_); + SchedLSTest::ScheduleAnalysis analysis(problem_); + solver.AnalyzeSchedule(machine_tasks_, current_durations_, &analysis); + const SchedLSTest::SolverState state = solver.ComputeDynamicState( + machine_tasks_, analysis.topo_order, current_durations_); + + // Check Tails + EXPECT_THAT(state.tails, ElementsAre(30, 20, 10, 0)); + + // Check Machine Lookups + EXPECT_THAT(state.prev_on_machine, ElementsAre(-1, 0, 1, -1)); + EXPECT_THAT(state.next_on_machine, ElementsAre(1, 2, -1, -1)); + + // Check Position Lookup + EXPECT_THAT(state.position_in_machine, ElementsAre(0, 1, 2, 0)); +} + +TEST_F(JSSPSolverTest, GenerateN8MovesExtractsValidInsertions) { + const SchedLSTest solver(problem_); + SchedLSTest::ScheduleAnalysis analysis(problem_); + solver.AnalyzeSchedule(machine_tasks_, current_durations_, &analysis); + const SchedLSTest::SolverState state = solver.ComputeDynamicState( + machine_tasks_, analysis.topo_order, current_durations_); + + const std::vector moves = solver.GenerateN8Moves( + analysis.critical_path, state.prev_on_machine, state.next_on_machine, + analysis.start_mins, state.tails, current_machines_, current_durations_); + + // The critical path is [0, 1, 2, 3]. + // Blocks: Block 0 is [0, 1, 2] on M0. Block 1 is [3] on M1. + // Block 0 is the FIRST block, so it tests Last-Op / Rightward rules. + // Valid unique moves generated for Block 0: + // 1. Task 0 moving right after last_op (2). + // 2. Task 1 moving right after last_op (2). + ASSERT_EQ(moves.size(), 2); + + // Task 0 placed after Task 2 + EXPECT_EQ(moves[0].task, 0); + EXPECT_EQ(moves[0].target_task, 2); + EXPECT_EQ(moves[0].place_after, true); + + // Task 1 placed after Task 2 + EXPECT_EQ(moves[1].task, 1); + EXPECT_EQ(moves[1].target_task, 2); + EXPECT_EQ(moves[1].place_after, true); +} + +TEST_F(JSSPSolverTest, EstimateMakespanForInsertAccuratelyPredictsImprovement) { + const SchedLSTest solver(problem_); + SchedLSTest::ScheduleAnalysis analysis(problem_); + solver.AnalyzeSchedule(machine_tasks_, current_durations_, &analysis); + const SchedLSTest::SolverState state = solver.ComputeDynamicState( + machine_tasks_, analysis.topo_order, current_durations_); + + // We propose taking Task 1 and inserting it AFTER Task 2. + const SchedLSTest::InsertMove insert_move{1, 2, true}; + SchedLSTest::MoveEvaluationScratch scratch; + + const IntegerValue estimated_makespan = solver.EstimateMakespanForInsert( + insert_move, analysis.start_mins, state.tails, state.prev_on_machine, + state.next_on_machine, current_durations_, &scratch); + + // Moving T1 after T2 pushes T2 to 10, T3 to 20(->30), T1 to 20(->30). + EXPECT_EQ(estimated_makespan, 30); +} + +TEST_F(JSSPSolverTest, BuildInitialMachineSequencesHandlesDirtyHints) { + // Let's provide a completely non-canonical hint: + const std::vector dirty_hint = { + 500, // Task 0 starts at 500 + 1000, // Task 1 starts at 1000 + 50, // Task 2 starts at 50 + 200 // Task 3 starts at 200 (Machine 1) + }; + + const SchedLSTest solver(problem_); + const CompactVectorVector generated_sequences = + solver.BuildInitialMachineSequences(dirty_hint, current_machines_); + + // We expect exactly 2 machines to be populated + ASSERT_EQ(generated_sequences.size(), 2); + + // Machine 0 should extract the sequence purely based on the sorted start + // times: T2 (50) comes first, then T0 (500), then T1 (1000). + EXPECT_THAT(generated_sequences[0], ElementsAre(2, 0, 1)); + + // Machine 1 only has Task 3. + EXPECT_THAT(generated_sequences[1], ElementsAre(3)); +} + +} // namespace +} // namespace sat +} // namespace operations_research diff --git a/ortools/sat/scheduling_model.cc b/ortools/sat/scheduling_model.cc index fcfa96da949..c6f0607cc57 100644 --- a/ortools/sat/scheduling_model.cc +++ b/ortools/sat/scheduling_model.cc @@ -15,24 +15,32 @@ #include #include -#include -#include +#include #include +#include #include #include #include -#include "absl/container/btree_map.h" +#include "absl/algorithm/container.h" #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/graph_base/connected_components.h" +#include "ortools/graph_base/topologicalsorter.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/cp_model_utils.h" +#include "ortools/sat/diffn_util.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/util.h" +#include "ortools/util/bitset.h" namespace operations_research { namespace sat { @@ -65,240 +73,541 @@ int64_t GetExprMin(const LinearExpressionProto& expr, return val; } -} // namespace +struct IntervalAffine { + int interval_index; + SchedulingProblemAndMapping::AffineExpr affine; -SchedulingProblemAndMapping BuildSchedulingProblemAndMapping( - const std::vector>& machine_to_tasks, - const std::vector>& task_precedences, - const BestBinaryRelationBounds& vars_precedences, - const CpModelProto& model_proto) { - // Detect the makespan. + bool operator<(const IntervalAffine& o) const { + return std::tie(affine.var, affine.coeff, interval_index, affine.offset) < + std::tie(o.affine.var, o.affine.coeff, o.interval_index, + o.affine.offset); + } - // First, let's check all the intervals that are not succeeded by any other - // interval. - absl::flat_hash_set seen_tasks; - for (const auto& precedence : task_precedences) { - seen_tasks.insert(precedence.first); + bool operator==(const IntervalAffine& o) const { + return std::tie(affine.var, affine.coeff, interval_index, affine.offset) == + std::tie(o.affine.var, o.affine.coeff, o.interval_index, + o.affine.offset); } - std::vector terminal_intervals; - for (const std::vector& machine_intervals : machine_to_tasks) { - for (int interval_idx : machine_intervals) { - if (seen_tasks.insert(interval_idx).second) { - terminal_intervals.push_back(interval_idx); +}; + +// Detects when the makespan is modeled as a physical "barrier" interval at the +// end of a NoOverlap constraint. Returns the interval indexes (to skip during +// task building) alongside the affine expression of the makespan. +std::vector DetectImplicitBarrierMakespans( + const CpModelProto& model_proto) { + int64_t horizon = kint64min; + for (const auto& var : model_proto.variables()) { + if (!var.domain().empty()) { + horizon = std::max(horizon, var.domain(var.domain().size() - 1)); + } + } + + std::vector barriers; + for (int i = 0; i < model_proto.constraints_size(); ++i) { + const auto& ct = model_proto.constraints(i); + if (!ct.has_no_overlap()) continue; + + for (int interval_idx : ct.no_overlap().intervals()) { + const auto& interval_ct = model_proto.constraints(interval_idx); + if (!interval_ct.enforcement_literal().empty()) continue; + + const auto& interval = interval_ct.interval(); + if (GetExprMin(interval.end(), model_proto) >= horizon && + GetExprMin(interval.size(), model_proto) > 0) { + if (interval.start().vars().size() == 1) { + barriers.push_back( + IntervalAffine{interval_idx, + {.var = interval.start().vars(0), + .coeff = interval.start().coeffs(0), + .offset = interval.start().offset()}}); + } } } } + return barriers; +} + +struct VariableBoundsGroup { + int var; + int64_t coeff; + int64_t max_offset; + std::vector bounded_intervals; +}; - VLOG(2) << "Detected " << terminal_intervals.size() - << " intervals that can be last."; +// Detects a makespan as a variable that is made to precede the end of all tasks +// by using precedence constraints. +std::optional +DetectMakespanFromPrecedenceGraph( + const std::vector& grouped_bounds, + const SchedulingProblemAndMapping& problem_and_mapping, + int64_t problem_start) { + SparseBitset visited(problem_and_mapping.problem.tasks.size()); + std::vector stack; + for (const VariableBoundsGroup& group : grouped_bounds) { + if (group.coeff <= 0) continue; + + std::vector explicitly_bounded_tasks; + for (int t = 0; t < problem_and_mapping.problem.tasks.size(); ++t) { + const auto& task_intervals = problem_and_mapping.task_to_intervals[t]; + + auto is_bounded = [&](int idx) { + return absl::c_binary_search(group.bounded_intervals, idx); + }; + + bool is_task_bounded = false; + if (absl::c_any_of(task_intervals.unconditional_intervals, is_bounded)) { + is_task_bounded = true; + } else if (!task_intervals.alternative_intervals.empty()) { + is_task_bounded = true; + for (int alt : task_intervals.alternative_intervals) { + if (!is_bounded(alt)) { + is_task_bounded = false; + break; + } + } + } + + if (is_task_bounded) explicitly_bounded_tasks.push_back(t); + } + + if (explicitly_bounded_tasks.empty()) continue; + + visited.ResetAllToFalse(); + stack = explicitly_bounded_tasks; + for (int t : stack) visited.Set(t); + + int reached_count = 0; + while (!stack.empty()) { + const int curr = stack.back(); + stack.pop_back(); + reached_count++; + + for (int prev : problem_and_mapping.problem.tasks[curr] + .tasks_that_must_complete_before_this) { + if (!visited[prev]) { + visited.Set(prev); + stack.push_back(prev); + } + } + } + + if (reached_count == problem_and_mapping.problem.tasks.size()) { + const int64_t exact_offset = + group.max_offset == kint64min ? problem_start : group.max_offset; + + return SchedulingProblemAndMapping::AffineExpr{ + .var = group.var, .coeff = group.coeff, .offset = exact_offset}; + } + } + return std::nullopt; +} + +// Parses CP-SAT constraints and returns a flat, grouped structure of all +// variables acting as an upper bound to intervals. +std::vector ExtractUpperBounds( + const CpModelProto& model_proto, + const BestBinaryRelationBounds& vars_precedences, + const absl::flat_hash_map& interval_to_machine) { CompactVectorVector intervals_by_end_var; { - std::vector> edges; - for (const int interval_idx : terminal_intervals) { - const IntervalConstraintProto& interval = - model_proto.constraints(interval_idx).interval(); + CompactVectorVectorBuilder builder; + builder.ReserveNumItems(interval_to_machine.size() * 2); + for (int i = 0; i < model_proto.constraints().size(); ++i) { + if (!model_proto.constraints(i).has_interval()) continue; + if (!model_proto.constraints(i).enforcement_literal().empty()) continue; + + const auto& interval = model_proto.constraints(i).interval(); if (interval.end().vars().size() == 1) { - const IntegerVariable end_var = - IntegerVariable(2 * interval.end().vars(0)); - if (interval.end().coeffs(0) > 0) { - edges.push_back({end_var, interval_idx}); - } else { - edges.push_back({NegationOf(end_var), interval_idx}); - } + const IntegerVariable end_var(2 * interval.end().vars(0)); + if (interval.end().coeffs(0) > 0) + builder.Add(end_var, i); + else + builder.Add(NegationOf(end_var), i); } } - intervals_by_end_var.ResetFromPairs(edges, model_proto.variables().size()); + intervals_by_end_var.ResetFromBuilder(builder, + model_proto.variables().size()); } - struct Bounds { - // intervals[interval_idx].end() <= var * coeff + offset - int interval_idx; - int var; - int64_t coeff; - int64_t offset; - }; - std::vector bounds; - for (const std::pair& bound : + std::vector raw_bounds; + for (const auto& [expr, offset] : vars_precedences.GetSortedNonTrivialUpperBounds()) { - if (bound.first.vars[0] == kNoIntegerVariable || - bound.first.vars[1] == kNoIntegerVariable) { + if (expr.vars[0] == kNoIntegerVariable || + expr.vars[1] == kNoIntegerVariable) { continue; } for (int i = 0; i < 2; ++i) { - const IntegerVariable var = bound.first.vars[i]; + const IntegerVariable var = expr.vars[i]; const int proto_var = GetPositiveOnlyIndex(var).value(); - const IntegerValue var_proto_coeff = VariableIsPositive(var) - ? bound.first.coeffs[i] - : -bound.first.coeffs[i]; - const IntegerVariable other_var = bound.first.vars[1 - i]; - const IntegerValue other_coeff = bound.first.coeffs[i]; + const IntegerValue var_proto_coeff = + VariableIsPositive(var) ? expr.coeffs[i] : -expr.coeffs[i]; + + const IntegerVariable other_var = expr.vars[1 - i]; + const IntegerValue other_coeff = expr.coeffs[1 - i]; const int other_proto_var = GetPositiveOnlyIndex(other_var).value(); const int64_t other_proto_coeff = VariableIsPositive(other_var) ? other_coeff.value() : -other_coeff.value(); + if (var >= intervals_by_end_var.size()) continue; + for (int interval_idx : intervals_by_end_var[var]) { - const IntervalConstraintProto& interval = - model_proto.constraints(interval_idx).interval(); - if (!model_proto.constraints(interval_idx) - .enforcement_literal() - .empty()) { - continue; - } + const auto& interval = model_proto.constraints(interval_idx).interval(); if (interval.end().vars().size() != 1) continue; + const int64_t interval_end_coeff = interval.end().coeffs(0); - if (interval_end_coeff % var_proto_coeff != 0) continue; + if (interval_end_coeff % var_proto_coeff.value() != 0) continue; + const int64_t factor = interval_end_coeff / var_proto_coeff.value(); DCHECK_EQ(interval.end().vars(0), proto_var); - bounds.push_back( - Bounds{interval_idx, other_proto_var, -other_proto_coeff * factor, - bound.second.value() * factor + interval.end().offset()}); + + raw_bounds.push_back(IntervalAffine{ + interval_idx, + {.var = other_proto_var, + .coeff = -other_proto_coeff * factor, + .offset = offset.value() * factor + interval.end().offset()}}); } } } - absl::flat_hash_map, int> - potential_var_coeff_bounds_to_interval_counts; - absl::flat_hash_set> seen_bounds; - for (const Bounds& bound : bounds) { - if (seen_bounds.insert({bound.var, bound.coeff, bound.interval_idx}) - .second) { - potential_var_coeff_bounds_to_interval_counts[{bound.var, bound.coeff}]++; + + gtl::STLSortAndRemoveDuplicates(&raw_bounds); + + // Group all intervals bounded by a specific variable + std::vector groups; + for (const IntervalAffine& bound : raw_bounds) { + if (groups.empty() || groups.back().var != bound.affine.var || + groups.back().coeff != bound.affine.coeff) { + groups.push_back({bound.affine.var, bound.affine.coeff, kint64min, {}}); + } + + // Only push unique interval indices into the contiguous group block + if (groups.back().bounded_intervals.empty() || + groups.back().bounded_intervals.back() != bound.interval_index) { + groups.back().bounded_intervals.push_back(bound.interval_index); } + + groups.back().max_offset = + std::max(groups.back().max_offset, bound.affine.offset); } - VLOG(2) << "Detected " << bounds.size() << " upper bounds."; - std::optional> makespan_var_coeff; - for (const auto& [var_coeff, count] : - potential_var_coeff_bounds_to_interval_counts) { - if (count == terminal_intervals.size()) { - makespan_var_coeff = var_coeff; + + return groups; +} + +// Helper to verify if two CP-SAT linear expressions are the same. +bool IsSameExpr(const LinearExpressionProto& a, + const LinearExpressionProto& b) { + if (a.offset() != b.offset()) return false; + if (a.vars_size() != b.vars_size()) return false; + for (int i = 0; i < a.vars_size(); ++i) { + if (a.vars(i) != b.vars(i) || a.coeffs(i) != b.coeffs(i)) return false; + } + return true; +} + +absl::flat_hash_map BuildTasksAndIntervalMapping( + const CpModelProto& model_proto, + const std::vector>& task_to_intervals, + const absl::flat_hash_map& interval_to_machine, + const absl::flat_hash_set& barrier_intervals, int64_t problem_start, + SchedulingProblemAndMapping* problem_and_mapping) { + absl::flat_hash_map interval_to_task_index; + + absl::flat_hash_map> start_var_to_intervals; + for (int i = 0; i < model_proto.constraints_size(); ++i) { + const auto& ct = model_proto.constraints(i); + // Unrestricted collection: we discover dummy intervals even if they have + // enforcement literals (critical for optional industrial tasks). + if (ct.has_interval() && ct.interval().start().vars().size() == 1) { + start_var_to_intervals[ct.interval().start().vars(0)].push_back(i); } } - SchedulingProblemAndMapping problem_and_mapping; - problem_and_mapping.problem.type = SchedulingProblem::kSatisfaction; - if (makespan_var_coeff.has_value()) { - int64_t offset = std::numeric_limits::min(); - for (const Bounds& bound : bounds) { - if (bound.var == makespan_var_coeff->first && - bound.coeff == makespan_var_coeff->second) { - offset = std::max(offset, bound.offset); - } - } - problem_and_mapping.makespan_expr = {makespan_var_coeff->first, - makespan_var_coeff->second, offset}; - VLOG(2) << "Detected makespan: " << makespan_var_coeff->first << " * " - << makespan_var_coeff->second << " + " << offset; - for (int i = 0; i < model_proto.objective().vars().size(); ++i) { - if (model_proto.objective().vars(i) == makespan_var_coeff->first) { - if (model_proto.objective().coeffs(i) > 0 == - makespan_var_coeff->second > 0) { - problem_and_mapping.problem.type = - SchedulingProblem::kMinimizeMakespan; - VLOG(2) << "Detected minimize makespan."; + + for (const std::vector& alts : task_to_intervals) { + if (absl::c_any_of( + alts, [&](int idx) { return barrier_intervals.contains(idx); })) { + continue; + } + + const int task_idx = problem_and_mapping->problem.tasks.size(); + auto& task = problem_and_mapping->problem.tasks.emplace_back(); + problem_and_mapping->task_to_intervals.push_back({}); + auto& task_intervals = problem_and_mapping->task_to_intervals.back(); + std::vector presence_literals; + + for (int interval_idx : alts) { + interval_to_task_index[interval_idx] = task_idx; + const auto& ct = model_proto.constraints(interval_idx); + + const bool is_conditional = !ct.enforcement_literal().empty(); + if (is_conditional) { + task_intervals.alternative_intervals.push_back(interval_idx); + } else { + task_intervals.unconditional_intervals.push_back(interval_idx); + } + + auto it = interval_to_machine.find(interval_idx); + if (it != interval_to_machine.end()) { + task.compatible_machine.push_back(it->second); + task.duration_for_machine.push_back(std::max( + int64_t{0}, GetExprMin(ct.interval().size(), model_proto))); + presence_literals.push_back(is_conditional ? ct.enforcement_literal(0) + : kint32max); + } + } + problem_and_mapping->task_to_presence_literals.push_back(presence_literals); + + const auto& first_interval = model_proto.constraints(alts[0]).interval(); + task.min_start = + std::max(int64_t{0}, GetExprMin(first_interval.start(), model_proto) - + problem_start); + + if (first_interval.start().vars().size() == 1) { + int base_start_var = first_interval.start().vars(0); + problem_and_mapping->task_to_start_time_model_var.push_back( + {base_start_var, first_interval.start().coeffs(0), + first_interval.start().offset() - problem_start}); + + if (auto it = start_var_to_intervals.find(base_start_var); + it != start_var_to_intervals.end()) { + for (int i : it->second) { + if (!barrier_intervals.contains(i) && + absl::c_find(alts, i) == alts.end()) { + // GUARD 1: Never claim a physical machine interval + if (interval_to_machine.contains(i)) continue; + + // GUARD 2: Restored from your original code! Never steal an + // interval already claimed by another task. Prevents industrial + // cannibalization. + if (interval_to_task_index.contains(i)) continue; + + const auto& dummy_interval = model_proto.constraints(i).interval(); + + // GUARD 3: Only check the start expression. CP-SAT gives main + // intervals distinct end variables, so we cannot check + // IsSameExpr(end). + if (IsSameExpr(dummy_interval.start(), first_interval.start())) { + task_intervals.unconditional_intervals.push_back(i); + interval_to_task_index[i] = task_idx; + } + } } } + } else if (first_interval.start().vars().empty()) { + problem_and_mapping->task_to_start_time_model_var.push_back( + {0, 0, first_interval.start().offset() - problem_start}); + } else { + problem_and_mapping->task_to_start_time_model_var.push_back({0, 1, 0}); } } + return interval_to_task_index; +} - absl::flat_hash_map interval_to_task_index; - for (int m = 0; m < machine_to_tasks.size(); ++m) { - const std::vector& machine_intervals = machine_to_tasks[m]; - for (int interval_idx : machine_intervals) { - const IntervalConstraintProto& interval = - model_proto.constraints(interval_idx).interval(); - const int64_t duration = GetExprMin(interval.size(), model_proto); - if (duration <= 0) { - continue; - } - const int task_idx = problem_and_mapping.problem.tasks.size(); - if (!interval_to_task_index.insert({interval_idx, task_idx}).second) { - // TODO(user): support "recipes" where a task must occupy more than - // one machine to complete. - continue; - } - if (interval.start().vars().size() != 1) { - continue; +// Extracts direct precedence edges between actual scheduled tasks. +// This performs a localized DFS through the interval adjacency graph to route +// seamlessly around any dummy or routing intervals inserted by the presolver. +// Crucially, the search stops immediately upon reaching the next valid task, +// computing a fast graph contraction rather than an expensive $O(N^2)$ full +// transitive closure. +void ExtractDirectTaskPrecedences( + const CpModelProto& model_proto, + const std::vector>& interval_precedences, + const absl::flat_hash_map& interval_to_task_index, + SchedulingProblemAndMapping* problem_and_mapping) { + CompactVectorVectorBuilder builder; + builder.ReserveNumItems(interval_precedences.size()); + for (const auto& precedence : interval_precedences) { + builder.Add(precedence.first, precedence.second); + } + CompactVectorVector interval_adj; + interval_adj.ResetFromBuilder(builder, model_proto.constraints_size()); + + SparseBitset visited(model_proto.constraints_size()); + std::vector stack; + + for (int task_idx = 0; task_idx < problem_and_mapping->problem.tasks.size(); + ++task_idx) { + visited.ResetAllToFalse(); + stack.clear(); + + const auto& task_intervals = + problem_and_mapping->task_to_intervals[task_idx]; + + // Seed DFS from ALL intervals properly mapped to the task + for (int interval_idx : task_intervals.alternative_intervals) { + stack.push_back(interval_idx); + visited.Set(interval_idx); + } + for (const int interval_idx : task_intervals.unconditional_intervals) { + stack.push_back(interval_idx); + visited.Set(interval_idx); + } + + while (!stack.empty()) { + const int curr = stack.back(); + stack.pop_back(); + + for (int next : interval_adj[curr]) { + if (visited[next]) continue; + visited.Set(next); + + auto it = interval_to_task_index.find(next); + if (it != interval_to_task_index.end()) { + int target_task_idx = it->second; + + if (target_task_idx != task_idx) { + problem_and_mapping->problem.tasks[target_task_idx] + .tasks_that_must_complete_before_this.push_back(task_idx); + } + continue; // Stop DFS gracefully at task boundaries + } + + // Unmapped node: Safely traverse it! Since we protected against + // cannibalization, the graph is trustworthy and won't cross-contaminate + // bounds. + stack.push_back(next); } - problem_and_mapping.task_to_start_time_model_var.push_back( - {interval.start().vars(0), interval.start().coeffs(0), - interval.start().offset()}); - SchedulingProblem::Task& task = - problem_and_mapping.problem.tasks.emplace_back(); - task.machine = m; - task.duration = duration; } } - for (const auto& precedence : task_precedences) { - const int from_task = interval_to_task_index[precedence.first]; - const int to_task = interval_to_task_index[precedence.second]; - problem_and_mapping.problem.tasks[to_task] - .tasks_that_must_complete_before_this.push_back(from_task); + + for (int i = 0; i < problem_and_mapping->problem.tasks.size(); ++i) { + gtl::STLSortAndRemoveDuplicates(&problem_and_mapping->problem.tasks[i] + .tasks_that_must_complete_before_this); } - return problem_and_mapping; } -SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { - // We will first detect a single large satisfaction job-shop problem, then - // we will try to detect one or more makespans that appear in the objective - // and split the problem accordingly. - std::vector> machine_to_tasks; - BestBinaryRelationBounds precedences; - for (const ConstraintProto& ct : model_proto.constraints()) { - if (!ct.enforcement_literal().empty()) continue; - if (ct.has_no_overlap()) { - machine_to_tasks.emplace_back(ct.no_overlap().intervals().begin(), - ct.no_overlap().intervals().end()); +} // namespace + +SchedulingProblemAndMapping BuildSchedulingProblemAndMapping( + const std::vector>& machine_to_intervals, + const std::vector>& task_to_intervals, + const std::vector>& interval_precedences, + const BestBinaryRelationBounds& vars_precedences, + const CpModelProto& model_proto) { + const int64_t problem_start = 0; + SchedulingProblemAndMapping problem_and_mapping; + + absl::flat_hash_map interval_to_machine; + for (int m = 0; m < machine_to_intervals.size(); ++m) { + for (int interval_idx : machine_to_intervals[m]) { + interval_to_machine[interval_idx] = m; } - if (ct.has_linear() && ct.linear().vars().size() == 2) { - const LinearConstraintProto& lin = ct.linear(); - const LinearExpression2 expr2 = GetLinearExpression2FromProto( - lin.vars(0), lin.coeffs(0), lin.vars(1), lin.coeffs(1)); - const IntegerValue lb(lin.domain(0)); - const IntegerValue ub(lin.domain(lin.domain().size() - 1)); + } - precedences.Add(expr2, lb, ub); + // --- 1. Detect Implicit Barrier Makespans --- + const std::vector barriers = + DetectImplicitBarrierMakespans(model_proto); + absl::flat_hash_set barrier_indices; + std::optional proven_makespan; + + for (const auto& b : barriers) { + barrier_indices.insert(b.interval_index); + if (!proven_makespan) proven_makespan = b.affine; + } + + // --- 2. Extract Upper Bounds --- + std::vector extracted_bounds = + ExtractUpperBounds(model_proto, vars_precedences, interval_to_machine); + + VLOG(2) << "Detected " << extracted_bounds.size() << " grouped upper bounds."; + + // --- 3. Build Tasks and Interval Mapping --- + absl::flat_hash_map interval_to_task_index = + BuildTasksAndIntervalMapping(model_proto, task_to_intervals, + interval_to_machine, barrier_indices, + problem_start, &problem_and_mapping); + + // --- 4. Extract Direct Task Precedences --- + ExtractDirectTaskPrecedences(model_proto, interval_precedences, + interval_to_task_index, &problem_and_mapping); + + // --- 5. Fallback: Precedence Graph Makespan Detection --- + if (!proven_makespan.has_value()) { + if (auto graph_result = DetectMakespanFromPrecedenceGraph( + extracted_bounds, problem_and_mapping, problem_start)) { + proven_makespan = *graph_result; } } + // --- 6. Objective Classification & Mapping --- + bool makespan_is_minimized = false; + + if (proven_makespan.has_value() && model_proto.has_objective()) { + for (int i = 0; i < model_proto.objective().vars().size(); ++i) { + const int obj_var = model_proto.objective().vars(i); + const int64_t obj_coeff = model_proto.objective().coeffs(i); + + if (obj_var == proven_makespan->var) { + makespan_is_minimized = + ((obj_coeff > 0) == (proven_makespan->coeff > 0)); + break; + } + } + } + + if (proven_makespan.has_value()) { + problem_and_mapping.makespan_expr = *proven_makespan; + problem_and_mapping.makespan_expr->offset -= problem_start; + + problem_and_mapping.problem.type = + makespan_is_minimized ? SchedulingProblem::kMinimizeMakespan + : SchedulingProblem::kSatisfaction; + + VLOG(2) << "Makespan detected: I" << proven_makespan->var << " * " + << proven_makespan->coeff << " + " + << proven_makespan->offset - problem_start + << (makespan_is_minimized ? " (minimized)" + : " (not on the objective)"); + } else { + VLOG(2) << "Could not map any variable to a makespan."; + problem_and_mapping.problem.type = SchedulingProblem::kSatisfaction; + } + + return problem_and_mapping; +} + +std::vector> DetectIntervalPrecedences( + const CpModelProto& model_proto, + const BestBinaryRelationBounds& precedences, + absl::Span interval_indices) { CompactVectorVector intervals_by_end_var; { - std::vector> edges; - for (const std::vector& intervals : machine_to_tasks) { - for (const int interval_idx : intervals) { - const IntervalConstraintProto& interval = - model_proto.constraints(interval_idx).interval(); - if (interval.end().vars().size() != 1) continue; - const IntegerVariable end_var = - IntegerVariable(2 * interval.end().vars(0)); - if (interval.end().coeffs(0) > 0) { - edges.push_back({end_var, interval_idx}); - } else { - edges.push_back({NegationOf(end_var), interval_idx}); - } + CompactVectorVectorBuilder + intervals_by_end_var_builder; + for (const int interval_idx : interval_indices) { + const IntervalConstraintProto& interval = + model_proto.constraints(interval_idx).interval(); + if (interval.end().vars().size() != 1) continue; + const IntegerVariable end_var = + IntegerVariable(2 * interval.end().vars(0)); + if (interval.end().coeffs(0) > 0) { + intervals_by_end_var_builder.Add(end_var, interval_idx); + } else { + intervals_by_end_var_builder.Add(NegationOf(end_var), interval_idx); } } - intervals_by_end_var.ResetFromPairs(edges, model_proto.variables().size()); + intervals_by_end_var.ResetFromBuilder(intervals_by_end_var_builder, + model_proto.variables().size()); } CompactVectorVector intervals_by_start_var; { - std::vector> edges; - for (const std::vector& intervals : machine_to_tasks) { - for (const int interval_idx : intervals) { - const IntervalConstraintProto& interval = - model_proto.constraints(interval_idx).interval(); - if (interval.start().vars().size() != 1) continue; - const IntegerVariable start_var = - IntegerVariable(2 * interval.start().vars(0)); - if (interval.start().coeffs(0) > 0) { - edges.push_back({start_var, interval_idx}); - } else { - edges.push_back({NegationOf(start_var), interval_idx}); - } + CompactVectorVectorBuilder + intervals_by_start_var_builder; + for (const int interval_idx : interval_indices) { + const IntervalConstraintProto& interval = + model_proto.constraints(interval_idx).interval(); + if (interval.start().vars().size() != 1) continue; + const IntegerVariable start_var = + IntegerVariable(2 * interval.start().vars(0)); + if (interval.start().coeffs(0) > 0) { + intervals_by_start_var_builder.Add(start_var, interval_idx); + } else { + intervals_by_start_var_builder.Add(NegationOf(start_var), interval_idx); } } - intervals_by_start_var.ResetFromPairs(edges, - model_proto.variables().size()); + intervals_by_start_var.ResetFromBuilder(intervals_by_start_var_builder, + model_proto.variables().size()); } std::vector> interval_precedences; @@ -322,21 +631,11 @@ SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { const IntervalConstraintProto& interval = model_proto.constraints(interval_idx).interval(); if (interval.start().vars().size() != 1) continue; - if (!model_proto.constraints(interval_idx) - .enforcement_literal() - .empty()) { - continue; - } const LinearExpressionProto& interval_start = interval.start(); for (int other_interval_idx : intervals_by_end_var[other_var]) { if (other_interval_idx == interval_idx) continue; const LinearExpressionProto& other_interval_end = model_proto.constraints(other_interval_idx).interval().end(); - if (!model_proto.constraints(other_interval_idx) - .enforcement_literal() - .empty()) { - continue; - } if (other_interval_end.vars().size() != 1) continue; // Do we know if other_interval.end <= interval.start ? @@ -356,8 +655,141 @@ SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { } } + return interval_precedences; +} + +SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { + // We will first detect a single large satisfaction job-shop problem, then + // we will try to detect one or more makespans that appear in the objective + // and split the problem accordingly. + std::vector> machine_to_intervals; + std::vector interval_indices; + absl::flat_hash_map> start_var_to_intervals; + std::vector all_intervals; + BestBinaryRelationBounds precedences; + + std::vector> exact_ones_groups; + + for (int c = 0; c < model_proto.constraints_size(); ++c) { + const ConstraintProto& ct = model_proto.constraints(c); + if (!ct.enforcement_literal().empty()) continue; + if (ct.has_no_overlap() && !ct.no_overlap().intervals().empty()) { + std::vector& intervals = machine_to_intervals.emplace_back(); + for (int interval_idx : ct.no_overlap().intervals()) { + intervals.push_back(interval_idx); + interval_indices.push_back(interval_idx); + } + } + + if (ct.has_linear() && ct.linear().vars().size() == 2) { + const LinearConstraintProto& lin = ct.linear(); + const LinearExpression2 expr2 = GetLinearExpression2FromProto( + lin.vars(0), lin.coeffs(0), lin.vars(1), lin.coeffs(1)); + const IntegerValue lb(lin.domain(0)); + const IntegerValue ub(lin.domain(lin.domain().size() - 1)); + + precedences.Add(expr2, lb, ub); + } + + if (ct.has_exactly_one()) { + exact_ones_groups.push_back( + std::vector{ct.exactly_one().literals().begin(), + ct.exactly_one().literals().end()}); + } + + if (ct.has_interval()) { + all_intervals.push_back(c); + if (ct.enforcement_literal().empty() && + ct.interval().start().vars().size() == 1) { + start_var_to_intervals[ct.interval().start().vars(0)].push_back(c); + } + } + } + gtl::STLSortAndRemoveDuplicates(&interval_indices); + + absl::flat_hash_map literal_to_interval; + for (int interval_idx : interval_indices) { + const ConstraintProto& ct = model_proto.constraints(interval_idx); + if (ct.enforcement_literal().size() == 1) { + literal_to_interval[ct.enforcement_literal(0)] = interval_idx; + } + } + + std::vector> task_to_intervals; + absl::flat_hash_set interval_in_task; + + for (const auto& group : exact_ones_groups) { + std::vector alts; + for (const int lit : group) { + auto it = literal_to_interval.find(lit); + if (it != literal_to_interval.end()) { + alts.push_back(it->second); + } + } + if (alts.size() > 1) { + bool already_used = false; + for (const int a : alts) { + if (interval_in_task.contains(a)) already_used = true; + } + if (!already_used) { + task_to_intervals.push_back(alts); + for (const int a : alts) interval_in_task.insert(a); + } + } + } + + for (auto& alts : task_to_intervals) { + DCHECK(!alts.empty()); + const auto& first_interval = model_proto.constraints(alts[0]).interval(); + + if (first_interval.start().vars().size() == 1) { + const int base_start_var = first_interval.start().vars(0); + + if (auto it = start_var_to_intervals.find(base_start_var); + it != start_var_to_intervals.end()) { + for (const int i : it->second) { + if (!interval_in_task.contains(i)) { + alts.push_back(i); + interval_in_task.insert(i); + } + } + } + } + } + + // Handle the intervals that appears both with b and ¬b as a task that can be + // executed on two machines. + for (const auto& [lit, interval_idx] : literal_to_interval) { + if (lit < 0) continue; + auto it = literal_to_interval.find(NegatedRef(lit)); + if (it != literal_to_interval.end()) { + int other_idx = it->second; + if (!interval_in_task.contains(interval_idx) && + !interval_in_task.contains(other_idx)) { + task_to_intervals.push_back({interval_idx, other_idx}); + interval_in_task.insert(interval_idx); + interval_in_task.insert(other_idx); + } + } + } + + // Handle the intervals that didn't appear in any exact one group as a task + // with a single machine choice. + for (int interval_idx : interval_indices) { + if (!interval_in_task.contains(interval_idx)) { + task_to_intervals.push_back({interval_idx}); + interval_in_task.insert(interval_idx); + } + } + + const std::vector> interval_precedences = + DetectIntervalPrecedences(model_proto, precedences, all_intervals); + VLOG(2) << "Detected " << interval_precedences.size() - << " interval_precedences."; + << " interval precedences."; + if (interval_precedences.empty()) { + return SchedulingRelaxation(); + } // Now that we have the interval precedences, we can imagine a graph where the // nodes are the intervals and the edges are the precedences. We can then @@ -370,19 +802,26 @@ SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { for (const auto& precedence : interval_precedences) { union_find.AddEdge(precedence.first, precedence.second); } - for (const std::vector& intervals : machine_to_tasks) { + for (const std::vector& intervals : machine_to_intervals) { for (int i = 1; i < intervals.size(); ++i) { union_find.AddEdge(intervals[i], intervals[0]); } } + for (const std::vector& alts : task_to_intervals) { + for (int i = 1; i < alts.size(); ++i) { + union_find.AddEdge(alts[i], alts[0]); + } + } + absl::btree_set component_roots; - for (const std::vector& intervals : machine_to_tasks) { + for (const std::vector& intervals : machine_to_intervals) { + DCHECK(!intervals.empty()); component_roots.insert(union_find.FindRoot(intervals[0])); } - VLOG(2) << "Detected " << component_roots.size() << " job-shop sub-problems."; struct PerComponentData { - std::vector> machine_to_tasks; + std::vector> machine_to_intervals; + std::vector> task_to_intervals; std::vector> precedences; }; std::vector per_component_data; @@ -398,47 +837,93 @@ SchedulingRelaxation DetectSchedulingProblems(const CpModelProto& model_proto) { if (it == root_to_index.end()) continue; per_component_data[it->second].precedences.push_back(precedence); } - for (const std::vector& intervals : machine_to_tasks) { + for (const std::vector& intervals : machine_to_intervals) { const int root = union_find.FindRoot(intervals[0]); const auto it = root_to_index.find(root); CHECK(it != root_to_index.end()); - per_component_data[it->second].machine_to_tasks.push_back(intervals); + per_component_data[it->second].machine_to_intervals.push_back(intervals); + } + for (const std::vector& alts : task_to_intervals) { + const int root = union_find.FindRoot(alts[0]); + const auto it = root_to_index.find(root); + if (it != root_to_index.end()) { + per_component_data[it->second].task_to_intervals.push_back(alts); + } } SchedulingRelaxation relaxation; absl::flat_hash_set makespan_vars; + std::vector num_machines_per_problem; + std::vector num_precedences_per_problem; + std::vector num_choices_per_problem; for (int i = 0; i < per_component_data.size(); ++i) { const PerComponentData& data = per_component_data[i]; relaxation.problems_and_mappings.emplace_back( BuildSchedulingProblemAndMapping( - data.machine_to_tasks, data.precedences, precedences, model_proto)); + data.machine_to_intervals, data.task_to_intervals, data.precedences, + precedences, model_proto)); + const SchedulingProblemAndMapping& problem_and_mapping = + relaxation.problems_and_mappings.back(); + + int num_machines = 0; + int num_precedences = 0; + int num_choices = 0; + for (const auto& task : problem_and_mapping.problem.tasks) { + for (int m : task.compatible_machine) { + num_machines = std::max(num_machines, m + 1); + } + num_precedences += task.tasks_that_must_complete_before_this.size(); + num_choices += task.compatible_machine.size(); + } + + if (problem_and_mapping.problem.tasks.size() < 3 || num_machines < 2) { + relaxation.problems_and_mappings.pop_back(); + continue; + } + num_machines_per_problem.push_back(num_machines); + num_precedences_per_problem.push_back(num_precedences); + num_choices_per_problem.push_back(num_choices); if (relaxation.problems_and_mappings.back().makespan_expr.has_value()) { makespan_vars.insert( relaxation.problems_and_mappings.back().makespan_expr->var); } } + + VLOG(2) << "Detected " << relaxation.problems_and_mappings.size() + << " job-shop sub-problems:"; + for (int i = 0; i < relaxation.problems_and_mappings.size(); ++i) { + const SchedulingProblemAndMapping& problem_and_mapping = + relaxation.problems_and_mappings[i]; + VLOG(2) << " " << i << ": " << problem_and_mapping.problem.tasks.size() + << " tasks, " << num_machines_per_problem[i] << " machines, " + << num_precedences_per_problem[i] << " precedences and " + << num_choices_per_problem[i] << " task-machine choices."; + } + + if (relaxation.problems_and_mappings.empty()) { + return relaxation; + } + SchedulingRelaxation::RelaxedObjective& relaxed_objective = relaxation.relaxed_objective; - relaxed_objective.offset = model_proto.objective().offset(); + relaxed_objective.offset = 0; for (int i = 0; i < model_proto.objective().vars().size(); ++i) { - if (makespan_vars.contains(model_proto.objective().vars(i))) { - relaxed_objective.var_in_problem_makespan.push_back( - model_proto.objective().vars(i)); - relaxed_objective.coeff.push_back(model_proto.objective().coeffs(i)); + const int var = model_proto.objective().vars(i); + const int64_t coeff = model_proto.objective().coeffs(i); + if (makespan_vars.contains(var)) { + relaxed_objective.var_in_problem_makespan.push_back(var); + relaxed_objective.coeff.push_back(coeff); } else { - const int var = model_proto.objective().vars(i); - if (model_proto.objective().coeffs(i) > 0) { - relaxed_objective.offset += model_proto.objective().coeffs(i) * - model_proto.variables(var).domain(0); + if (coeff > 0) { + relaxed_objective.offset += + coeff * model_proto.variables(var).domain(0); } else { relaxed_objective.offset += - model_proto.objective().coeffs(i) * - model_proto.variables(var).domain( - model_proto.variables(var).domain().size() - 1); + coeff * model_proto.variables(var).domain( + model_proto.variables(var).domain().size() - 1); } } } - relaxation.relaxed_objective.offset = model_proto.objective().offset(); return relaxation; } @@ -448,59 +933,121 @@ bool VerifySingleSchedulingProblem( VLOG(2) << "Verifying Scheduling problem with " << relaxation.problem.tasks.size() << " tasks."; + auto lit_value = [&solution](int lit) { + DCHECK_NE(lit, kint32max); + const bool is_true = RefIsPositive(lit) ? (solution[lit] == 1) + : (solution[NegatedRef(lit)] == 0); + return is_true; + }; + // First, check task precedences. for (int task_idx = 0; task_idx < relaxation.problem.tasks.size(); ++task_idx) { const SchedulingProblem::Task& task = relaxation.problem.tasks[task_idx]; + + // Skip if the task is not present. + bool is_present = false; + int active_machine_idx = 0; + for (int a = 0; a < task.compatible_machine.size(); ++a) { + int lit = relaxation.task_to_presence_literals[task_idx][a]; + if (lit == kint32max || lit_value(lit)) { + active_machine_idx = a; + is_present = true; + break; + } + } + if (!is_present) continue; + const int start_time_var = relaxation.task_to_start_time_model_var[task_idx].var; const int64_t start_time = solution[start_time_var] * relaxation.task_to_start_time_model_var[task_idx].coeff + relaxation.task_to_start_time_model_var[task_idx].offset; + for (int before_task : task.tasks_that_must_complete_before_this) { + bool before_is_present = false; + int before_active_machine_idx = 0; + for (int a = 0; + a < relaxation.problem.tasks[before_task].compatible_machine.size(); + ++a) { + int lit = relaxation.task_to_presence_literals[before_task][a]; + if (lit == kint32max || lit_value(lit)) { + before_active_machine_idx = a; + before_is_present = true; + break; + } + } + if (!before_is_present) continue; + const int before_task_end_var = relaxation.task_to_start_time_model_var[before_task].var; const int64_t before_task_end_time = solution[before_task_end_var] * relaxation.task_to_start_time_model_var[before_task].coeff + relaxation.task_to_start_time_model_var[before_task].offset + - relaxation.problem.tasks[before_task].duration; + relaxation.problem.tasks[before_task] + .duration_for_machine[before_active_machine_idx]; if (start_time < before_task_end_time) { + VLOG(2) << "Task " << relaxation.problem.tasks[before_task] + << " does not complete before task " << task; return false; } } } + int num_machines = 0; + for (const auto& t : relaxation.problem.tasks) { + for (int m : t.compatible_machine) { + num_machines = std::max(num_machines, m + 1); + } + } + // Now, check that at no time a machine is used by more than one task. - absl::flat_hash_map> - machine_to_intervals; + std::vector>> machine_intervals( + num_machines); for (int task_idx = 0; task_idx < relaxation.problem.tasks.size(); ++task_idx) { const SchedulingProblem::Task& task = relaxation.problem.tasks[task_idx]; + + bool is_present = false; + int active_machine_idx = 0; + for (int a = 0; a < task.compatible_machine.size(); ++a) { + int lit = relaxation.task_to_presence_literals[task_idx][a]; + if (lit == kint32max || lit_value(lit)) { + active_machine_idx = a; + is_present = true; + break; + } + } + if (!is_present) continue; + const int start_time_var = relaxation.task_to_start_time_model_var[task_idx].var; const int64_t start_time = solution[start_time_var] * relaxation.task_to_start_time_model_var[task_idx].coeff + relaxation.task_to_start_time_model_var[task_idx].offset; - const int64_t end_time = start_time + task.duration; - if (!machine_to_intervals[task.machine] - .insert({start_time, end_time}) - .second) { - return false; - } + + const int64_t end_time = + start_time + task.duration_for_machine[active_machine_idx]; + machine_intervals[task.compatible_machine[active_machine_idx]].push_back( + {start_time, end_time}); } - for (const auto& [machine, intervals] : machine_to_intervals) { - auto it = intervals.begin(); - while (it != intervals.end()) { - const int64_t end_time = it->second; - const auto next_it = std::next(it); - if (next_it == intervals.end()) break; - if (next_it->first < end_time) { + + for (int machine = 0; machine < machine_intervals.size(); ++machine) { + auto& intervals = machine_intervals[machine]; + absl::c_sort(intervals); // Sorts by start_time ascending + for (int i = 1; i < intervals.size(); ++i) { + if (intervals[i - 1].second > intervals[i].first) { + VLOG(2) << "Overlap detected on machine " << machine + << " at task with " + "time interval [" + << intervals[i - 1].first << ", " << intervals[i - 1].second + << "] and task with time interval [" << intervals[i].first + << ", " << intervals[i].second << "]"; return false; } - it = next_it; } } VLOG(2) << "Scheduling problem verified."; @@ -511,67 +1058,179 @@ CpModelProto BuildSchedulingModel(const SchedulingProblem& problem) { CpModelProto model_proto; int64_t horizon = 0; int num_machines = 0; + for (const SchedulingProblem::Task& task : problem.tasks) { - num_machines = std::max(num_machines, task.machine + 1); - horizon += task.duration; + for (const int m : task.compatible_machine) { + num_machines = std::max(num_machines, m + 1); + } + int64_t max_dur = 0; + for (const int64_t d : task.duration_for_machine) { + max_dur = std::max(max_dur, d); + } + horizon += max_dur; } + std::vector no_overlap_constraints(num_machines); for (int i = 0; i < num_machines; ++i) { no_overlap_constraints[i] = model_proto.add_constraints()->mutable_no_overlap(); } + const int num_tasks = problem.tasks.size(); - std::vector task_idx_to_interval_idx(num_tasks); + + // Helper struct to unify task end expressions (var + offset) + struct AffineExpr { + int var; + int64_t offset; + }; + + std::vector task_start_vars(num_tasks); + std::vector task_ends(num_tasks); + + // Track predecessors for the makespan "sink task" optimization + std::vector is_predecessor(num_tasks, false); + + // Build variables and interval constraints for each task. for (int i = 0; i < num_tasks; ++i) { - IntegerVariableProto* var = model_proto.add_variables(); - var->add_domain(0); - var->add_domain(horizon); - task_idx_to_interval_idx[i] = model_proto.constraints().size(); - IntervalConstraintProto* interval = - model_proto.add_constraints()->mutable_interval(); - interval->mutable_start()->add_vars(i); - interval->mutable_start()->add_coeffs(1); - interval->mutable_size()->set_offset(problem.tasks[i].duration); - interval->mutable_end()->add_vars(i); - interval->mutable_end()->add_coeffs(1); - interval->mutable_end()->set_offset(problem.tasks[i].duration); - no_overlap_constraints[problem.tasks[i].machine]->add_intervals( - task_idx_to_interval_idx[i]); + const auto& task = problem.tasks[i]; + int num_alts = task.compatible_machine.size(); + + IntegerVariableProto* start_var = model_proto.add_variables(); + start_var->add_domain(0); + start_var->add_domain(horizon); + int start_var_idx = model_proto.variables().size() - 1; + task_start_vars[i] = start_var_idx; + + // Mark any task that must complete *before* this one as a predecessor + for (int p : task.tasks_that_must_complete_before_this) { + is_predecessor[p] = true; + } + + if (num_alts == 1) { + // Unify: End = StartVar + Duration + task_ends[i] = {start_var_idx, task.duration_for_machine[0]}; + + IntervalConstraintProto* interval = + model_proto.add_constraints()->mutable_interval(); + + interval->mutable_start()->add_vars(start_var_idx); + interval->mutable_start()->add_coeffs(1); + + interval->mutable_size()->set_offset(task.duration_for_machine[0]); + + interval->mutable_end()->add_vars(start_var_idx); + interval->mutable_end()->add_coeffs(1); + interval->mutable_end()->set_offset(task.duration_for_machine[0]); + + int interval_idx = model_proto.constraints().size() - 1; + no_overlap_constraints[task.compatible_machine[0]]->add_intervals( + interval_idx); + + } else if (num_alts > 1) { + IntegerVariableProto* end_var = model_proto.add_variables(); + end_var->add_domain(0); + end_var->add_domain(horizon); + int end_var_idx = model_proto.variables().size() - 1; + + // Unify: End = EndVar + 0 + task_ends[i] = {end_var_idx, 0}; + + int64_t min_dur = task.duration_for_machine[0]; + int64_t max_dur = task.duration_for_machine[0]; + for (int64_t d : task.duration_for_machine) { + min_dur = std::min(min_dur, d); + max_dur = std::max(max_dur, d); + } + + IntegerVariableProto* dur_var = model_proto.add_variables(); + dur_var->add_domain(min_dur); + dur_var->add_domain(max_dur); + int dur_var_idx = model_proto.variables().size() - 1; + + IntervalConstraintProto* main_interval = + model_proto.add_constraints()->mutable_interval(); + main_interval->mutable_start()->add_vars(start_var_idx); + main_interval->mutable_start()->add_coeffs(1); + main_interval->mutable_size()->add_vars(dur_var_idx); + main_interval->mutable_size()->add_coeffs(1); + main_interval->mutable_end()->add_vars(end_var_idx); + main_interval->mutable_end()->add_coeffs(1); + + auto* exactly_one = model_proto.add_constraints()->mutable_exactly_one(); + + for (int a = 0; a < num_alts; ++a) { + IntegerVariableProto* bool_var = model_proto.add_variables(); + bool_var->add_domain(0); + bool_var->add_domain(1); + int bool_var_idx = model_proto.variables().size() - 1; + + exactly_one->add_literals(bool_var_idx); + + ConstraintProto* alt_ct = model_proto.add_constraints(); + alt_ct->add_enforcement_literal(bool_var_idx); + IntervalConstraintProto* alt_interval = alt_ct->mutable_interval(); + alt_interval->mutable_start()->add_vars(start_var_idx); + alt_interval->mutable_start()->add_coeffs(1); + alt_interval->mutable_size()->set_offset(task.duration_for_machine[a]); + alt_interval->mutable_end()->add_vars(end_var_idx); + alt_interval->mutable_end()->add_coeffs(1); + + int alt_interval_idx = model_proto.constraints().size() - 1; + no_overlap_constraints[task.compatible_machine[a]]->add_intervals( + alt_interval_idx); + } + } } + + // Build precedence constraints. for (int task_idx = 0; task_idx < num_tasks; ++task_idx) { - const SchedulingProblem::Task& current_task = problem.tasks[task_idx]; - for (int preceding_task_idx : + const auto& current_task = problem.tasks[task_idx]; + for (int preceding_idx : current_task.tasks_that_must_complete_before_this) { - // preceding_task.start + preceding_task.duration <= current_task.start + // current_start >= preceding_end + // current_start - preceding_end.var >= preceding_end.offset LinearConstraintProto* linear = model_proto.add_constraints()->mutable_linear(); - linear->add_vars(task_idx); - linear->add_coeffs(-1); - linear->add_vars(preceding_task_idx); + + linear->add_vars(task_start_vars[task_idx]); linear->add_coeffs(1); - linear->add_domain(-2 * horizon); - linear->add_domain(-problem.tasks[preceding_task_idx].duration); + linear->add_vars(task_ends[preceding_idx].var); + linear->add_coeffs(-1); + + linear->add_domain(task_ends[preceding_idx].offset); + linear->add_domain(kint64max); } } + + // Build makespan constraints using sink tasks only. const int makespan_var = model_proto.variables().size(); IntegerVariableProto* var = model_proto.add_variables(); var->add_domain(0); var->add_domain(horizon); + for (int task_idx = 0; task_idx < num_tasks; ++task_idx) { - // task.start + task.duration <= makespan + // Optimization: Only constrain tasks that have no successors! + if (is_predecessor[task_idx]) continue; + + // makespan >= task_end + // makespan - task_end.var >= task_end.offset LinearConstraintProto* linear = model_proto.add_constraints()->mutable_linear(); - linear->add_vars(task_idx); - linear->add_coeffs(1); + linear->add_vars(makespan_var); + linear->add_coeffs(1); + linear->add_vars(task_ends[task_idx].var); linear->add_coeffs(-1); - linear->add_domain(-2 * horizon); - linear->add_domain(-problem.tasks[task_idx].duration); + + linear->add_domain(task_ends[task_idx].offset); + linear->add_domain(kint64max); } + if (problem.type == SchedulingProblem::kMinimizeMakespan) { model_proto.mutable_objective()->add_vars(makespan_var); model_proto.mutable_objective()->add_coeffs(1); } + return model_proto; } @@ -589,13 +1248,527 @@ bool VerifySchedulingRelaxation(const SchedulingRelaxation& relaxation, for (int i = 0; i < relaxation.relaxed_objective.var_in_problem_makespan.size(); ++i) { const int var = relaxation.relaxed_objective.var_in_problem_makespan[i]; - const int64_t coeff = relaxation.relaxed_objective.coeff[i]; - *relaxed_objective_value += coeff * solution[var]; + const int64_t global_coeff = relaxation.relaxed_objective.coeff[i]; + + int64_t var_lower_bound = kint64min; + + for (const auto& problem_and_mapping : relaxation.problems_and_mappings) { + if (!problem_and_mapping.makespan_expr.has_value() || + problem_and_mapping.makespan_expr->var != var) { + continue; + } + // 1. Calculate the true local makespan directly from the tasks + int64_t problem_makespan = 0; + for (int task_idx = 0; + task_idx < problem_and_mapping.problem.tasks.size(); ++task_idx) { + const int start_time_var = + problem_and_mapping.task_to_start_time_model_var[task_idx].var; + const int64_t start_time = + solution[start_time_var] * + problem_and_mapping.task_to_start_time_model_var[task_idx] + .coeff + + problem_and_mapping.task_to_start_time_model_var[task_idx].offset; + + const SchedulingProblem::Task& task = + problem_and_mapping.problem.tasks[task_idx]; + int active_machine_idx = 0; + for (int a = 0; a < task.compatible_machine.size(); ++a) { + int lit = problem_and_mapping.task_to_presence_literals[task_idx][a]; + if (lit == kint32max) { + active_machine_idx = a; + break; + } + const bool is_true = RefIsPositive(lit) + ? (solution[lit] == 1) + : (solution[NegatedRef(lit)] == 0); + if (is_true) { + active_machine_idx = a; + break; + } + } + + int64_t duration = problem_and_mapping.problem.tasks[task_idx] + .duration_for_machine[active_machine_idx]; + problem_makespan = std::max(problem_makespan, start_time + duration); + } + + // 2. Translate the local task makespan to the global variable space. + // global_var * coeff + offset >= problem_makespan + // global_var >= (problem_makespan - offset) / coeff + const int64_t diff = + problem_makespan - problem_and_mapping.makespan_expr->offset; + const int64_t local_coeff = problem_and_mapping.makespan_expr->coeff; + + // Use integer ceiling division to ensure the lower bound remains + // strictly valid (assuming local_coeff > 0 for a minimize makespan + // relation) + const int64_t translated_bound = + (diff > 0) ? (diff + local_coeff - 1) / local_coeff + : diff / local_coeff; + + var_lower_bound = std::max(var_lower_bound, translated_bound); + } + *relaxed_objective_value += global_coeff * var_lower_bound; } VLOG(2) << "Scheduling relaxation verified with objective value: " << *relaxed_objective_value; return true; } +bool ProbableSplitExists(const GraphForPartition& graph) { + const int kSamples = 64; + + if (graph.num_primary_nodes <= kSamples) { + return true; // Too small for the heuristic to be useful. + } + + std::vector primary_topo_indices; + primary_topo_indices.reserve(graph.num_primary_nodes); + for (int i = 0; i < graph.num_nodes; ++i) { + if (graph.topological_order[i] < graph.num_primary_nodes) { + primary_topo_indices.push_back(i); + } + } + + std::vector sampled_nodes; + sampled_nodes.reserve(kSamples); + const int num_p = primary_topo_indices.size(); + + // Pick kNumSamples nodes following the following rules: + // 1. Always pick the first and last primary nodes in topological order. This + // ensures we cannot miss something that always follows or always precedes + // the sampled nodes. + // 2. Sample remaining nodes well distributed across the topological + // order, biased towards lower degree nodes (note that a node of degree + // zero immediately proves the graph is unsplittable). + + sampled_nodes.push_back(primary_topo_indices.front()); + for (int i = 1; i < kSamples - 1; ++i) { + const int start_idx = (i * (num_p - 1)) / (kSamples - 1); + const int end_idx = + std::max(((i + 1) * (num_p - 1)) / (kSamples - 1), start_idx + 1); + + int best_topo = primary_topo_indices[start_idx]; + int min_degree = graph.adj[graph.topological_order[best_topo]].size(); + + for (int j = start_idx + 1; j < end_idx; ++j) { + const int node_topo = primary_topo_indices[j]; + const int node_degree = + graph.adj[graph.topological_order[node_topo]].size(); + if (node_degree < min_degree) { + min_degree = node_degree; + best_topo = node_topo; + } + } + sampled_nodes.push_back(best_topo); + } + sampled_nodes.push_back(primary_topo_indices.back()); + + // Map original node IDs to their new sampled node IDs. + std::vector node_to_sampled(graph.num_nodes, -1); + for (int i = 0; i < kSamples; ++i) { + node_to_sampled[graph.topological_order[sampled_nodes[i]]] = i; + } + + // Build the connectivity of the sampled graph by computing reachability + // within the original graph. + CompactVectorVector sampled_adj; + sampled_adj.reserve(kSamples, kSamples * 4); + std::vector visited; + + for (int i = 0; i < kSamples; ++i) { + sampled_adj.Add({}); + visited.assign(graph.num_nodes, false); + + const int start_node = graph.topological_order[sampled_nodes[i]]; + visited[start_node] = true; + + for (int t = sampled_nodes[i]; t < graph.num_nodes; ++t) { + const int u = graph.topological_order[t]; + if (!visited[u]) continue; + + if (u != start_node && node_to_sampled[u] != -1) { + sampled_adj.AppendToLastVector(node_to_sampled[u]); + continue; // We don't need to build the transitive closure. + } + + for (const int c : graph.adj[u]) { + visited[c] = true; + } + } + } + + std::vector sampled_topo(kSamples); + std::iota(sampled_topo.begin(), sampled_topo.end(), 0); + + GraphForPartition sampled_graph{/*num_nodes=*/kSamples, + /*num_primary_nodes=*/kSamples, + /*adj=*/sampled_adj, + /*topological_order=*/sampled_topo}; + + return PartitionByIncomparabilityExact(sampled_graph).size() > 1; +} + +CompactVectorVector PartitionByIncomparabilityExact( + const GraphForPartition& graph) { + // The first step to solve this problem is to notice that the partitions must + // be contiguous in the topological order. Thus, we need to find right "cut" + // nodes to split the topological order into our partitions. To do that, we + // write a helper lambda that for a given node finds the last node that is + // unreachable from it downstream in the topological order. The important + // trick is that this last unreachable node is also a lower bound of where the + // next cut can be made. The straightforward solution would be then to iterate + // on current_node starting from 0 and get the maximum of those lower bounds + // until we get to the point where lower_bound >= current_node; when that + // happens, we make a cut and start a new partition. We do something that is + // similar to that, but searching backwards: we start from the first node, + // compute the current lower_bound from its last unreachable primary node, and + // then iterate backward from lower_bound-1 to 0. If we find a new + // lower_bound, jump again to the new lower_bound-1 and iterate backward. To + // avoid exploring the same region twice we keep a vector of + // next_to_check[node] that holds how far backward from node has already been + // checked. The advantage of searching backwards is to optimistically prove + // that the graph in unsplittable with a few large increments of lower_bound + // or, if the graph is splittable, to quickly jump to the next split position. + + if (graph.num_primary_nodes == 0) { + return CompactVectorVector(); + } + // For simplicity, we re-index the graph so its indexes correspond to the + // topological order. + std::vector is_primary(graph.num_nodes, false); + int last_primary_node = -1; + CompactVectorVector adj; // Adjacency list indexed by topological order. + { + std::vector inverse_topo(graph.num_nodes); + for (int i = 0; i < graph.num_nodes; ++i) { + const bool is_primary_node = + graph.topological_order[i] < graph.num_primary_nodes; + is_primary[i] = is_primary_node; + inverse_topo[graph.topological_order[i]] = i; + + if (is_primary_node) { + last_primary_node = i; + } + } + + adj.reserve(graph.num_nodes, graph.adj.num_entries()); + for (int i = 0; i < graph.num_nodes; ++i) { + adj.Add({}); + for (const int c : graph.adj[graph.topological_order[i]]) { + adj.AppendToLastVector(inverse_topo[c]); + } + } + } + + std::vector is_reachable(graph.num_nodes, false); + int work_done = 0; + + // Computes the last unreachable primary node starting from `node` and + // assuming all primary nodes in the range [node, assume_reachable_up_to] can + // reach the whole graph below them in the topological order. Returns `node` + // if it can reach the whole graph below it or the last unreachable primary + // node otherwise. + auto compute_last_unreachable_primary = [&](int node, + int assume_reachable_up_to) { + is_reachable.assign(graph.num_nodes, false); + is_reachable[node] = true; + + // Populate is_reachable + for (int j = node; j <= last_primary_node; ++j) { + if (!is_reachable[j]) continue; + for (const int c : adj[j]) { + if (c <= assume_reachable_up_to && is_primary[c]) { + // We hit a child that can reach the rest of the graph, so we can + // reach all those nodes from `node`. Stop. + return assume_reachable_up_to; + } + is_reachable[c] = true; + } + work_done += 1 + adj[j].size(); + } + + // Find the last primary among the unreachable. + int last_unreachable_primary = node; + for (int j = last_primary_node; j > node; --j) { + if (!is_reachable[j] && is_primary[j]) { + last_unreachable_primary = j; + break; + } + } + return last_unreachable_primary; + }; + + std::vector primary_to_check; + primary_to_check.reserve(graph.num_primary_nodes); + + CompactVectorVector result; + int last_cut_position = -1; + int start_node = 0; + + // Limit to avoid quadratic explosion on large inputs. + constexpr int kWorkLimit = 100'000'000; + + while (start_node <= last_primary_node) { + // Fast-forward to the next primary node + while (start_node <= last_primary_node && !is_primary[start_node]) { + start_node++; + } + if (start_node > last_primary_node) break; + + int proposed_cut = compute_last_unreachable_primary(start_node, -1); + + // For each partition we start by checking all the primary nodes up to + // proposed_cut. Note that elements in primary_to_check are always in + // increasing order. + primary_to_check.clear(); + for (int i = start_node + 1; i <= proposed_cut; ++i) { + if (is_primary[i]) { + primary_to_check.push_back(i); + } + } + + // Verify backwards down to start_node + while (!primary_to_check.empty() && proposed_cut < last_primary_node) { + const int node_to_check = primary_to_check.back(); + primary_to_check.pop_back(); + + // Note that we already checked all the nodes in the range + // [node_to_check, proposed_cut] since we are iterating backwards from + // proposed_cut. + const int lup = + compute_last_unreachable_primary(node_to_check, proposed_cut); + + if (work_done > kWorkLimit) { + // Jump to the end to write the current partition and stop. + proposed_cut = last_primary_node; + } + + if (lup > proposed_cut) { + // The boundary expanded! Push the newly exposed primary nodes onto the + // stack after the existing ones. + for (int i = proposed_cut + 1; i <= lup; ++i) { + if (is_primary[i]) { + primary_to_check.push_back(i); + } + } + proposed_cut = lup; + } + } + + // Write the partition translating back to original node IDs. + result.Add({}); + for (int i = last_cut_position + 1; i <= proposed_cut; ++i) { + if (is_primary[i]) { + result.AppendToLastVector(graph.topological_order[i]); + } + } + + // Advance the lower_bound and prepare for the next partition. + last_cut_position = proposed_cut; + start_node = proposed_cut + 1; + } + + return result; +} + +// Runs the DAG Incomparability Partitioning heuristics. +CompactVectorVector PartitionByIncomparability( + int num_nodes, int num_primary_nodes, const CompactVectorVector& adj) { + auto maybe_topo = util::graph::FastTopologicalSort(adj); + if (!maybe_topo.ok()) { + VLOG(2) << "PartitionByIncomparability found a cycle!"; + CompactVectorVector trivial_partition; + trivial_partition.Add({}); + for (int i = 0; i < num_primary_nodes; ++i) { + trivial_partition.AppendToLastVector(i); + } + + return trivial_partition; + } + + const GraphForPartition graph{num_nodes, num_primary_nodes, adj, + maybe_topo.value()}; + + // Quick heuristic to avoid expensive quadratic computation for most + // unsplittable graphs. + if (!ProbableSplitExists(graph)) { + VLOG(3) << "Quick heuristic deemed graph unsplittable."; + CompactVectorVector result; + result.Add({}); + for (int i = 0; i < num_primary_nodes; ++i) { + result.AppendToLastVector(i); + } + return result; + } + + const CompactVectorVector result = + PartitionByIncomparabilityExact(graph); + + if (result.size() == 1 && result[0].size() == num_primary_nodes) { + VLOG(3) << "PartitionByIncomparability found graph is non-splittable."; + } else { + VLOG(3) << "PartitionByIncomparability found: [" + << absl::StrJoin(result.AsVectorOfSpan(), ",", + [](std::string* out, absl::Span v) { + absl::StrAppend(out, v.size()); + }) + << "]"; + } + return result; +} + +CompactVectorVector IntervalsNonOverlappingComponents( + absl::Span intervals, + const std::vector>& precedences) { + // We want to map the problem of splitting our intervals into a pure graph + // theory problem. To do that, we augment the precedence graph with extra + // nodes and edges to represent the fact that start and end times imposes a + // precedence on the intervals. + // + // The construction is as follows: + // + // 1. Nodes: + // - [0, N-1]: The actual Intervals. + // - [N, N+K-1]: "Arrival" time nodes (t_in). Represents the exact moment + // of time t. + // - [N+K, N+2K-1]: "Departure" time nodes (t_out). Represents time moving + // forward from t. + // + // 2. The time backbone edges: + // We chain time points together: Arrival(t) -> Departure(t) -> + // Arrival(t+1) + // + // 3. The interval time edges: + // An interval spanning from start_time to end_time is hooked as: + // Departure(start_time) -> [Interval] -> Arrival(end_time) + // + // We need two nodes per time point (start and end) to handle zero-duration + // intervals: a zero-duration interval Z at time t must happen after an + // interval ending at t, and before an interval starting at t. By splitting + // time t into Arr(t) and Dep(t), we can hook zero-duration intervals exactly + // in the middle: + // Arrival(t) -> [Zero-Duration Interval] -> Departure(t) + // + // VISUALIZATION: + // Consider I1 [t0..t1], I2 [t1..t2], and Z [t1..t1] + // + // Arr(t0) -> Dep(t0) ---> Arr(t1) -> Dep(t1) ---> Arr(t2) -> Dep(t2) + // | ^ | ^ | ^ + // | [I1] | | [Z] | | [I2] | + // +------------+ +-------+ +------------+ + // + // Notice the paths: + // - I1 reaches I2 via: I1 -> Arr(t1) -> Dep(t1) -> I2. (Guaranteed A finishes + // before B starts). + // - I1 reaches Z via: I1 -> Arr(t1) -> Z. + // - Z reaches I2 via: Z -> Dep(t1) -> I2. + // - If I3 was [t0..t2], it would bypass t1 entirely, meaning it cannot reach + // I1, and I1 cannot reach it. + + const int N = intervals.size(); + if (N <= 1) { + CompactVectorVector res; + for (const auto& inv : intervals) { + res.Add({inv.index}); + } + return res; + } + + std::vector times; + times.reserve(N * 2); + for (const auto& inv : intervals) { + times.push_back(inv.start.value()); + times.push_back(inv.end.value()); + } + gtl::STLSortAndRemoveDuplicates(×); + + const int K = times.size(); + const int first_t_in_index = N; + const int first_t_out_index = N + K; + const int total_nodes = N + 2 * K; + + absl::flat_hash_map time_to_index(times.size()); + for (int i = 0; i < times.size(); ++i) { + time_to_index[times[i]] = i + first_t_in_index; + } + + CompactVectorVectorBuilder adj; + adj.ReserveNumItems(2 * K + 2 * N + precedences.size()); + + for (int i = 0; i < K; ++i) { + const int t_in_index = first_t_in_index + i; + const int t_out_index = first_t_out_index + i; + adj.Add(t_in_index, t_out_index); + if (i < K - 1) { + const int next_t_in_index = first_t_in_index + i + 1; + adj.Add(t_out_index, next_t_in_index); + } + } + + for (int i = 0; i < N; ++i) { + const bool is_zero_duration = + intervals[i].start.value() == intervals[i].end.value(); + const int t_in_start_index = time_to_index[intervals[i].start.value()]; + const int t_in_end_index = time_to_index[intervals[i].end.value()]; + const int t_out_start_index = t_in_start_index + K; + const int t_out_end_index = t_in_end_index + K; + + if (is_zero_duration) { + adj.Add(t_in_start_index, i); + adj.Add(i, t_out_end_index); + } else { + adj.Add(t_out_start_index, i); + adj.Add(i, t_in_end_index); + } + } + + absl::flat_hash_map id_to_internal; + for (int i = 0; i < N; ++i) { + id_to_internal[intervals[i].index] = i; + } + for (const auto& prec : precedences) { + auto it1 = id_to_internal.find(prec.first); + auto it2 = id_to_internal.find(prec.second); + if (it1 != id_to_internal.end() && it2 != id_to_internal.end()) { + adj.Add(it1->second, it2->second); + } + } + + const CompactVectorVector adj_vec(adj, total_nodes); + + const auto partitions = PartitionByIncomparability(total_nodes, N, adj_vec); + + CompactVectorVector result; + + std::vector mapped; + for (int i = 0; i < partitions.size(); ++i) { + mapped.clear(); + for (const int u : partitions[i]) { + // Only keep nodes that represent intervals + if (u < N) mapped.push_back(intervals[u].index); + } + if (!mapped.empty()) { + result.Add(mapped); + } + } + + if (result.size() == 1 && result[0].size() == N) { + VLOG(3) << "IntervalsNonOverlappingComponents: " << N << " intervals and " + << precedences.size() + << " precedences and the found graph is non-splittable."; + } else { + VLOG(3) << "IntervalsNonOverlappingComponents: " << N << " intervals and " + << precedences.size() << " precedences, components=[" + << absl::StrJoin(result.AsVectorOfSpan(), ",", + [](std::string* out, absl::Span v) { + absl::StrAppend(out, v.size()); + }) + << "]"; + } + + return result; +} + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/scheduling_model.h b/ortools/sat/scheduling_model.h index 1e24bf8d336..fa59f053322 100644 --- a/ortools/sat/scheduling_model.h +++ b/ortools/sat/scheduling_model.h @@ -17,12 +17,17 @@ #include #include #include +#include #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "ortools/sat/cp_model.pb.h" +#include "ortools/sat/diffn_util.h" +#include "ortools/sat/integer_base.h" +#include "ortools/sat/util.h" namespace operations_research { namespace sat { @@ -36,9 +41,19 @@ namespace sat { // problem is a start time for each task. struct SchedulingProblem { struct Task { - int machine; - int64_t duration; + std::vector compatible_machine; + std::vector duration_for_machine; std::vector tasks_that_must_complete_before_this; + int64_t min_start = 0; + + template + friend void AbslStringify(Sink& sink, const Task& task) { + absl::Format( + &sink, "Task(machines: [%v], durations: [%v], run_after: [%v])", + absl::StrJoin(task.compatible_machine, ","), + absl::StrJoin(task.duration_for_machine, ","), + absl::StrJoin(task.tasks_that_must_complete_before_this, ",")); + } }; std::vector tasks; @@ -52,11 +67,12 @@ struct SchedulingProblem { : "satisfaction"), absl::StrJoin( problem.tasks, ",", [](std::string* out, const Task& task) { - absl::Format(out, - "Task(machine: %v, duration: %v, run_after: %v)", - task.machine, task.duration, - absl::StrJoin( - task.tasks_that_must_complete_before_this, ",")); + absl::Format( + out, "Task(machines: [%v], durations: [%v], run_after: [%v])", + absl::StrJoin(task.compatible_machine, ","), + absl::StrJoin(task.duration_for_machine, ","), + absl::StrJoin(task.tasks_that_must_complete_before_this, + ",")); })); } }; @@ -67,13 +83,23 @@ struct SchedulingProblem { struct SchedulingProblemAndMapping { SchedulingProblem problem; - struct ShiftedVar { + struct AffineExpr { int var; // The variable in the CpModelProto. int64_t coeff; int64_t offset; }; - std::vector task_to_start_time_model_var; - std::optional makespan_expr; + std::vector task_to_start_time_model_var; + std::optional makespan_expr; + std::vector> task_to_presence_literals; + + struct TaskIntervalVars { + // Conditional intervals that map 1:1 with `compatible_machine`. + std::vector alternative_intervals; + + // Unconditional intervals (e.g., dummies used for precedence routing). + std::vector unconditional_intervals; + }; + std::vector task_to_intervals; }; // A relaxation of the CpModelProto as a set of independent scheduling problems @@ -82,7 +108,10 @@ struct SchedulingRelaxation { std::vector problems_and_mappings; struct RelaxedObjective { - // `var_in_problem_makespan` must match one of the makespan_expr.var. + // `var_in_problem_makespan` must match at least one of the + // makespan_expr.var. If it matches the makespan_expr of several + // problems_and_mappings, one must pick the largest value to get the + // objective value of the relaxation. std::vector var_in_problem_makespan; std::vector coeff; int64_t offset; @@ -107,6 +136,103 @@ bool VerifySchedulingRelaxation(const SchedulingRelaxation& relaxation, absl::Span solution, int64_t* relaxed_objective_value); +// Detects all the precedences between intervals from `precedences`. +// This return pairs of "interval constraint indices" in the given proto where +// we are sure that end(pair.first) <= start(pair.second) in all feasible +// solution. +std::vector> DetectIntervalPrecedences( + const CpModelProto& model_proto, + const BestBinaryRelationBounds& precedences, + absl::Span interval_indices); + +// Splits the set of intervals into components that cannot overlap. More +// precisely, for two components C1 and C2 one of the following must be true: +// - every interval in C1 must end at or before any interval in C2 starts; +// - every interval in C2 must end at or before any interval in C1 starts. +// +// This function takes into account both the `precedences` and the trivial +// precedences from the interval start and end times. It runs in O(num_intervals +// + num_precedences) time for "easy" cases and O(num_intervals * +// (num_precedences) in the worst case. It always returns the largest number of +// components possible. It doesn't do anything if the number of intervals is too +// large to avoid spending too much time in this quadratic algorithm. +CompactVectorVector IntervalsNonOverlappingComponents( + absl::Span intervals, + const std::vector>& precedences); + +// ----------------------------------------------------------------------------- +// Induced DAG Incomparability Partitioning +// ----------------------------------------------------------------------------- +// Given a Directed Acyclic Graph (DAG) G = (V, E) and a contiguous subset of +// "Primary Nodes" P (where P comprises the indices [0, num_primary_nodes - 1]), +// the goal is to partition P into the maximum number of disjoint subsets +// S_1, ..., S_k such that for any two primary vertices u and v belonging to +// DIFFERENT subsets, there exists a directed path between them in G (either +// u ~> v or v ~> u). +// +// Equivalently: If there is no directed path between u and v in either +// direction, they must be placed in the SAME subset. In graph theory terms, +// this is computing the connected components of the induced incomparability +// subgraph on P. +// +// The resulting partition forms a strict total order of subsets, meaning all +// directed paths between distinct subsets flow in a single direction +// (S_i -> S_j where i < j). +// +// === Why the distinction between primary and non-primary nodes? === +// In scheduling, explicitly drawing direct edges between every single pair of +// comparable tasks with incompatible start/end times creates a massive O(N^2) +// edge explosion. To avoid the quadratic size, we introduce auxiliary nodes to +// represent points in time and model the incompatibilities as precedences +// (edges) in the graph. +// +// However, these time nodes exist purely to pass paths forward. Because +// they don't have edges to *every* parallel task happening at the same time, +// the standard graph math sees them as "incomparable" to those parallel tasks. +// If we naively partition the entire graph, these routing nodes act as false +// "glue," accidentally merging completely independent sequences of tasks +// together. +// +// By strictly evaluating incomparability ONLY on the Primary Nodes, we achieve +// the memory savings of the auxiliary routing infrastructure without suffering +// from its mathematical side-effects. +// +// === The Algorithms === +// We built two algorithms to tackle this efficiently: +// 1. ProbableSplitExists(): A fast O(V + E) probabilistic filter that samples a +// subset of primary nodes and computes if the subset can be partitioned to +// quickly detect unsplittable graphs. +// 2. PartitionByIncomparabilityExact(): An exact O(V * E) solver that +// in the worst case computes the reachability for each primary node. +// +// We then combine these into PartitionByIncomparability() which runs the filter +// first and only runs the exact solver when the filter suggests the graph is +// splittable. +// +// Fun fact: if this sounds over-complicated, there is an algorithm [1] that can +// solve the general incomparability problem optimally in O(V + E) time. +// However, it's so complex that not even Gemini had the stomach to implement +// it. +// +// [1] McConnell, Ross M., and Fabien De Montgolfier. "Linear-time modular +// decomposition of directed graphs." Discrete Applied Mathematics 145.2 (2005): +// 198-209. + +// Represents a DAG that has been validated cycle-free and topologically sorted. +struct GraphForPartition { + int num_nodes; + int num_primary_nodes; + const CompactVectorVector& adj; + const std::vector& topological_order; +}; +CompactVectorVector PartitionByIncomparability( + int num_nodes, int num_primary_nodes, const CompactVectorVector& adj); + +// The two underlying algorithms are exposed for testing. +CompactVectorVector PartitionByIncomparabilityExact( + const GraphForPartition& graph); +bool ProbableSplitExists(const GraphForPartition& graph); + } // namespace sat } // namespace operations_research diff --git a/ortools/sat/shaving_solver.cc b/ortools/sat/shaving_solver.cc index 7ad97e7b044..0581f60c013 100644 --- a/ortools/sat/shaving_solver.cc +++ b/ortools/sat/shaving_solver.cc @@ -550,11 +550,9 @@ void VariablesShavingSolver::CopyModelConnectedToVar( return cc_finder.Connected(root_index, ct_to_node(c)); }; - PresolveContext context(local_model, shaving_proto, nullptr); std::vector interval_mapping; - ImportModelAndDomainsWithBasicPresolveIntoContext( - model_proto_, var_domains_, active_constraints, &context, - &interval_mapping); + CopyModelAdvanced(model_proto_, var_domains_, active_constraints, + &interval_mapping, shaving_proto, local_model); // Now copy the ignored constraints "partially". for (const int c : ignored_constraints) { @@ -697,7 +695,6 @@ bool VariablesShavingSolver::ResetAndSolveModel(int64_t task_id, State* state, auto context = std::make_unique(local_model, shaving_proto, &mapping_proto); context->InitializeNewDomains(); - context->UpdateNewConstraintsVariableUsage(); const int num_constraints = shaving_proto->constraints().size(); std::vector useful_interval(num_constraints, false); std::vector no_overalp_2d; diff --git a/ortools/sat/simplification.cc b/ortools/sat/simplification.cc index 0af99dd32e6..865124846ba 100644 --- a/ortools/sat/simplification.cc +++ b/ortools/sat/simplification.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include @@ -31,6 +30,7 @@ #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/graph_base/strongly_connected_components.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -859,7 +859,7 @@ LiteralIndex SatPresolver::FindLiteralWithShortestOccurrenceListExcluding( const std::vector& clause, Literal to_exclude) { DCHECK(!clause.empty()); LiteralIndex result = kNoLiteralIndex; - int num_occurrences = std::numeric_limits::max(); + int num_occurrences = kint32max; for (const Literal l : clause) { if (l == to_exclude) continue; if (literal_to_clause_sizes_[l] < num_occurrences) { diff --git a/ortools/sat/synchronization.cc b/ortools/sat/synchronization.cc index 8c8d3ef43df..e44eed05cf0 100644 --- a/ortools/sat/synchronization.cc +++ b/ortools/sat/synchronization.cc @@ -52,6 +52,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/macros/os_support.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/integer_base.h" @@ -267,7 +268,7 @@ SharedResponseManager::SharedResponseManager(Model* model) : parameters_(*model->GetOrCreate()), wall_timer_(*model->GetOrCreate()), shared_time_limit_(model->GetOrCreate()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), solution_pool_(parameters_), logger_(model->GetOrCreate()) { bounds_logging_id_ = logger_->GetNewThrottledId(); @@ -529,7 +530,7 @@ IntegerValue SharedResponseManager::GetInnerObjectiveUpperBound() { } void SharedResponseManager::Synchronize() { - solution_pool_.Synchronize(*random_); + solution_pool_.Synchronize(random_); absl::MutexLock mutex_lock(mutex_); synchronized_inner_objective_lower_bound_ = @@ -808,7 +809,7 @@ SharedResponseManager::NewSolution(absl::Span solution_values, // In single thread, no one is synchronizing the solution manager, so we // should do it from here. if (always_synchronize_) { - solution_pool_.Synchronize(*random_); + solution_pool_.Synchronize(random_); first_solution_solvers_should_stop_ = true; } @@ -930,12 +931,10 @@ SolverStatusChangeInfo SharedResponseManager::GetSolverStatusChangeInfo() { SharedBoundsManager::SharedBoundsManager(const CpModelProto& model_proto) : num_variables_(model_proto.variables_size()), model_proto_(model_proto), - lower_bounds_(num_variables_, std::numeric_limits::min()), - upper_bounds_(num_variables_, std::numeric_limits::max()), - synchronized_lower_bounds_(num_variables_, - std::numeric_limits::min()), - synchronized_upper_bounds_(num_variables_, - std::numeric_limits::max()) { + lower_bounds_(num_variables_, kint64min), + upper_bounds_(num_variables_, kint64max), + synchronized_lower_bounds_(num_variables_, kint64min), + synchronized_upper_bounds_(num_variables_, kint64max) { changed_variables_since_last_synchronize_.ClearAndResize(num_variables_); for (int i = 0; i < num_variables_; ++i) { lower_bounds_[i] = model_proto.variables(i).domain(0); @@ -959,18 +958,17 @@ SharedBoundsManager::SharedBoundsManager(const CpModelProto& model_proto) var_to_orbit_index_ = GetOrbits(num_vars, generators); // Fill orbits_. - std::vector keys; - std::vector values; + CompactVectorVectorBuilder orbits_builder; + orbits_builder.ReserveNumItems(num_vars); for (int var = 0; var < num_vars; ++var) { const int orbit_index = var_to_orbit_index_[var]; if (orbit_index == -1) continue; - keys.push_back(orbit_index); - values.push_back(var); + orbits_builder.Add(orbit_index, var); } - if (keys.empty()) return; + orbits_.ResetFromBuilder(orbits_builder); + if (orbits_.empty()) return; has_symmetry_ = true; - orbits_.ResetFromFlatMapping(keys, values); // Fill representative. var_to_representative_.resize(num_vars); @@ -1131,6 +1129,7 @@ void SharedBoundsManager::Synchronize() { } } changed_variables_since_last_synchronize_.ResetAllToFalse(); + synchronized_timestamp_ = timestamp_; } int SharedBoundsManager::RegisterNewId(absl::string_view name) { @@ -1182,7 +1181,7 @@ void SharedBoundsManager::GetChangedBounds( bounds_stats_[id_to_name_[id]].num_imported += variables->size(); } if (timestamp != nullptr) { - *timestamp = timestamp_; + *timestamp = synchronized_timestamp_; } } @@ -1429,6 +1428,13 @@ void SharedClausesManager::AddEdge(LiteralIndex a, LiteralIndex b) { } const LiteralIndex rep_a = GetRepresentative(a); const LiteralIndex rep_b = GetRepresentative(b); + + if (Literal(rep_a) == Literal(rep_b).Negated()) { + // This happens when we just proved a model UNSAT. Do nothing to preserve + // the invariant representative[negated(a)] == negated(representative[a]). + return; + } + // Always use the min as the new parent, in order to guarantee that the // representative of not(a) is the negation of the representative of a. On the // other hand, this does not give the shallowest new tree. This gives a less @@ -1527,12 +1533,13 @@ void SharedClausesManager::LogStatistics(SolverLogger* logger) { absl::c_sort(name_to_table_line); std::vector> table; table.push_back({"Clauses shared", "#Exported", "#Imported", "#BinaryRead", - "#BinaryTotal"}); + "#BinaryTotal", "#EquivTotal"}); for (const auto& [name, exported, imported, binary_read, binary_total] : name_to_table_line) { table.push_back({FormatName(name), FormatCounter(exported), FormatCounter(imported), FormatCounter(binary_read), - FormatCounter(binary_total)}); + FormatCounter(binary_total), + FormatCounter(num_equivalences_)}); } SOLVER_LOG(logger, FormatTable(table)); } diff --git a/ortools/sat/synchronization.h b/ortools/sat/synchronization.h index bd593ee2946..c6cf673158c 100644 --- a/ortools/sat/synchronization.h +++ b/ortools/sat/synchronization.h @@ -44,6 +44,7 @@ #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" #include "ortools/base/timer.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/integer_base.h" #include "ortools/sat/model.h" @@ -134,7 +135,7 @@ class SharedSolutionRepository { std::shared_ptr GetSolution(int index) const; // Returns the rank of the best known solution. If there is no solution, this - // will return std::numeric_limits::max(). + // will return kint64max. int64_t GetBestRank() const; std::vector> GetBestNSolutions(int n) const; @@ -294,10 +295,8 @@ class SharedSolutionPool { // best known solution. We usually never select seeds_[0] but keep it around // for later in case new best solutions are found. absl::Mutex mutex_; - int64_t max_rank_ ABSL_GUARDED_BY(mutex_) = - std::numeric_limits::min(); - int64_t min_rank_ ABSL_GUARDED_BY(mutex_) = - std::numeric_limits::max(); + int64_t max_rank_ ABSL_GUARDED_BY(mutex_) = kint64min; + int64_t min_rank_ ABSL_GUARDED_BY(mutex_) = kint64max; std::vector ranks_; std::vector< std::shared_ptr::Solution>> @@ -586,7 +585,7 @@ class SharedResponseManager { const SatParameters& parameters_; const WallTimer& wall_timer_; ModelSharedTimeLimit* shared_time_limit_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; CpObjectiveProto const* objective_or_null_ = nullptr; mutable absl::Mutex mutex_; @@ -602,18 +601,15 @@ class SharedResponseManager { SharedSolutionPool solution_pool_; // Thread-safe. int num_solutions_ ABSL_GUARDED_BY(mutex_) = 0; - int64_t inner_objective_lower_bound_ ABSL_GUARDED_BY(mutex_) = - std::numeric_limits::min(); - int64_t inner_objective_upper_bound_ ABSL_GUARDED_BY(mutex_) = - std::numeric_limits::max(); - int64_t best_solution_objective_value_ ABSL_GUARDED_BY(mutex_) = - std::numeric_limits::max(); + int64_t inner_objective_lower_bound_ ABSL_GUARDED_BY(mutex_) = kint64min; + int64_t inner_objective_upper_bound_ ABSL_GUARDED_BY(mutex_) = kint64max; + int64_t best_solution_objective_value_ ABSL_GUARDED_BY(mutex_) = kint64max; bool always_synchronize_ ABSL_GUARDED_BY(mutex_) = true; - IntegerValue synchronized_inner_objective_lower_bound_ ABSL_GUARDED_BY( - mutex_) = IntegerValue(std::numeric_limits::min()); - IntegerValue synchronized_inner_objective_upper_bound_ ABSL_GUARDED_BY( - mutex_) = IntegerValue(std::numeric_limits::max()); + IntegerValue synchronized_inner_objective_lower_bound_ + ABSL_GUARDED_BY(mutex_) = IntegerValue(kint64min); + IntegerValue synchronized_inner_objective_upper_bound_ + ABSL_GUARDED_BY(mutex_) = IntegerValue(kint64max); bool update_integral_on_each_change_ ABSL_GUARDED_BY(mutex_) = false; double gap_integral_ ABSL_GUARDED_BY(mutex_) = 0.0; @@ -742,6 +738,7 @@ class SharedBoundsManager { std::vector synchronized_upper_bounds_ ABSL_GUARDED_BY(mutex_); std::deque> id_to_changed_variables_ ABSL_GUARDED_BY(mutex_); + int64_t synchronized_timestamp_ ABSL_GUARDED_BY(mutex_) = 0; std::vector id_to_name_ ABSL_GUARDED_BY(mutex_); @@ -1086,7 +1083,7 @@ SharedSolutionRepository::GetSolution(int i) const { template int64_t SharedSolutionRepository::GetBestRank() const { absl::MutexLock mutex_lock(mutex_); - if (solutions_.empty()) return std::numeric_limits::max(); + if (solutions_.empty()) return kint64max; return solutions_[0]->rank; } @@ -1203,9 +1200,8 @@ void SharedSolutionRepository::Synchronize( } } - const int64_t old_best_rank = solutions_.empty() - ? std::numeric_limits::max() - : solutions_[0]->rank; + const int64_t old_best_rank = + solutions_.empty() ? kint64max : solutions_[0]->rank; solutions_.insert(solutions_.end(), new_solutions_.begin(), new_solutions_.end()); @@ -1283,7 +1279,7 @@ void SharedSolutionRepository::Synchronize( solutions_.resize(new_size); if (VLOG_IS_ON(3)) { - int min_count = std::numeric_limits::max(); + int min_count = kint32max; int max_count = 0; for (const auto& s : solutions_) { CHECK(s != nullptr); diff --git a/ortools/sat/synchronization_test.cc b/ortools/sat/synchronization_test.cc index f8b99c27234..65102bf0a39 100644 --- a/ortools/sat/synchronization_test.cc +++ b/ortools/sat/synchronization_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "ortools/base/gmock.h" #include "ortools/base/parse_test_proto.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/integer_base.h" @@ -414,10 +415,8 @@ TEST(SharedResponseManagerTest, GapIntegralTest) { // Unknown count as max possible difference. shared_time_limit->AdvanceDeterministicTime(1.0); shared_response->UpdateGapIntegral(); - const double value1 = - 1.0 * - log(1 + 4 * (static_cast(std::numeric_limits::max()) - - static_cast(std::numeric_limits::min()))); + const double value1 = 1.0 * log(1 + 4 * (static_cast(kint64max) - + static_cast(kint64min))); EXPECT_EQ(value1, shared_response->GapIntegral()); // No time, so still same. But the function height will change. diff --git a/ortools/sat/util.cc b/ortools/sat/util.cc index 3e1239719ac..297848810ab 100644 --- a/ortools/sat/util.cc +++ b/ortools/sat/util.cc @@ -36,6 +36,7 @@ #include "ortools/base/log_severity.h" #include "ortools/base/macros/os_support.h" #include "ortools/base/mathutil.h" +#include "ortools/base/types.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" #include "ortools/util/saturated_arithmetic.h" @@ -198,8 +199,8 @@ bool SolveDiophantineEquationOfSizeTwo(int64_t& a, int64_t& b, int64_t& cte, int64_t& x0, int64_t& y0) { CHECK_NE(a, 0); CHECK_NE(b, 0); - CHECK_NE(a, std::numeric_limits::min()); - CHECK_NE(b, std::numeric_limits::min()); + CHECK_NE(a, kint64min); + CHECK_NE(b, kint64min); const int64_t gcd = std::gcd(std::abs(a), std::abs(b)); if (cte % gcd != 0) return false; @@ -231,8 +232,8 @@ bool SolveDiophantineEquationOfSizeTwo(int64_t& a, int64_t& b, int64_t& cte, // - a * x0 <= cte, in this case y0 will not overflow (<= cte). // - a * x0 > cte, in this case y0 will be in (-a, 0]. const absl::int128 r = t / absl::int128{b}; - DCHECK_LE(r, absl::int128{std::numeric_limits::max()}); - DCHECK_GE(r, absl::int128{std::numeric_limits::min()}); + DCHECK_LE(r, absl::int128{kint64max}); + DCHECK_GE(r, absl::int128{kint64min}); y0 = static_cast(r); return true; @@ -700,7 +701,7 @@ BasicKnapsackSolver::Result BasicKnapsackSolver::InternalSolve( const std::vector& prev = var_activity_states_[i - 1]; std::vector& current = var_activity_states_[i]; for (int prev_value = 0; prev_value < num_values; ++prev_value) { - if (prev[prev_value].cost == std::numeric_limits::max()) { + if (prev[prev_value].cost == kint64max) { continue; } for (const int64_t v : domains_[i].Values()) { @@ -719,7 +720,7 @@ BasicKnapsackSolver::Result BasicKnapsackSolver::InternalSolve( Result result; result.solved = true; - int64_t best_cost = std::numeric_limits::max(); + int64_t best_cost = kint64max; int64_t best_activity; for (int v = 0; v < num_values; ++v) { // TODO(user): optimize this? @@ -730,7 +731,7 @@ BasicKnapsackSolver::Result BasicKnapsackSolver::InternalSolve( } } - if (best_cost == std::numeric_limits::max()) { + if (best_cost == kint64max) { result.infeasible = true; return result; } @@ -989,7 +990,7 @@ std::vector FindMostDiverseSubset(int k, int n, if (k == n - 1) { // We just exclude the one closer to all the other. - int64_t worse = std::numeric_limits::max(); + int64_t worse = kint64max; int to_exclude = -1; for (int i = 0; i < n; ++i) { if ((always_pick_mask >> i) & 1) continue; diff --git a/ortools/sat/util.h b/ortools/sat/util.h index dd2a7d9d223..afeaacd8ffb 100644 --- a/ortools/sat/util.h +++ b/ortools/sat/util.h @@ -37,6 +37,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_base.h" #include "ortools/sat/sat_parameters.pb.h" @@ -65,6 +66,26 @@ class IdentityMap { T operator[](T t) const { return t; } }; +template +class CompactVectorVector; + +template +class CompactVectorVectorBuilder { + public: + void Add(const K& key, const V& value); + + void ReserveNumItems(int64_t num_items); + + void Clear(); + + private: + friend class CompactVectorVector; + + std::vector key_buffer_; + std::vector value_buffer_; + K max_key_ = K(0); +}; + // Small utility class to store a vector> where one can only append new // vector and never change previously added ones. This allows to store a static // key -> value(s) mapping. @@ -78,6 +99,12 @@ class CompactVectorVector { public: using value_type = V; + CompactVectorVector() = default; + explicit CompactVectorVector(const CompactVectorVectorBuilder& builder, + int minimum_num_keys = 0) { + ResetFromBuilder(builder, minimum_num_keys); + } + // Size of the "key" space, always in [0, size()). size_t size() const; bool empty() const; @@ -103,6 +130,9 @@ class CompactVectorVector { buffer_.reserve(num_entries); } + void ResetFromBuilder(const CompactVectorVectorBuilder& builder, + int minimum_num_keys = 0); + // Given a flat mapping (keys[i] -> values[i]) with two parallel vectors, not // necessarily sorted by key, regroup the same key so that // CompactVectorVector[key] list all values in the order in which they appear. @@ -168,6 +198,9 @@ class CompactVectorVector { // This will crash if there are more values than before. void ReplaceValuesBySmallerSet(K key, absl::Span values); + // Sorts the values at the given key and removes duplicates. + void SortAndRemoveDuplicateValues(K key); + // Shrinks the inner vector size of the given key. void Shrink(K key, int new_size); @@ -178,9 +211,6 @@ class CompactVectorVector { } private: - // Convert int and StrongInt to normal int. - static int InternalKey(K key); - std::vector starts_; std::vector sizes_; std::vector buffer_; @@ -452,26 +482,12 @@ bool LinearInequalityCanBeReducedWithClosestMultiple( // ones. class ModelRandomGenerator : public absl::BitGenRef { public: - // We seed the strategy at creation only. This should be enough for our use - // case since the SatParameters is set first before the solver is created. We - // also never really need to change the seed afterwards, it is just used to - // diversify solves with identical parameters on different Model objects. - explicit ModelRandomGenerator(const SatParameters& params) - : absl::BitGenRef(deterministic_random_) { - deterministic_random_.seed(params.random_seed()); - if (params.use_absl_random()) { - absl_random_ = absl::BitGen(absl::SeedSeq({params.random_seed()})); - absl::BitGenRef::operator=(absl::BitGenRef(absl_random_)); - } - } - explicit ModelRandomGenerator(const absl::BitGenRef& bit_gen_ref) - : absl::BitGenRef(deterministic_random_) { - absl::BitGenRef::operator=(bit_gen_ref); - } + : absl::BitGenRef(bit_gen_ref) {} explicit ModelRandomGenerator(Model* model) - : ModelRandomGenerator(*model->GetOrCreate()) {} + : absl::BitGenRef( + model->GetOrCreate()->bit_gen_ref()) {} // This is just used to display ABSL_RANDOM_SALT_OVERRIDE in the log so that // it is possible to reproduce a failure more easily while looking at a solver @@ -481,8 +497,30 @@ class ModelRandomGenerator : public absl::BitGenRef { void LogSalt() const {} private: - random_engine_t deterministic_random_; - absl::BitGen absl_random_; + class ModelRandomEngine { + public: + // We seed the strategy at creation only. This should be enough for our use + // case since the SatParameters is set first before the solver is created. + // We also never really need to change the seed afterwards, it is just used + // to diversify solves with identical parameters on different Model objects. + explicit ModelRandomEngine(Model* model) { + const SatParameters& params = *model->GetOrCreate(); + if (params.use_absl_random()) { + absl_random_ = absl::BitGen(absl::SeedSeq({params.random_seed()})); + absl_bit_gen_ref_ = absl::BitGenRef(absl_random_); + } else { + deterministic_random_.seed(params.random_seed()); + absl_bit_gen_ref_ = absl::BitGenRef(deterministic_random_); + } + } + + absl::BitGenRef bit_gen_ref() const { return absl_bit_gen_ref_; } + + private: + random_engine_t deterministic_random_; + absl::BitGen absl_random_; + absl::BitGenRef absl_bit_gen_ref_ = absl::BitGenRef(deterministic_random_); + }; }; // The model "singleton" shared time limit. @@ -622,7 +660,7 @@ class FirstFewValues { void Reset() { for (int i = 0; i < n; ++i) { - reachable_[i] = std::numeric_limits::max(); + reachable_[i] = kint64max; } reachable_[0] = 0; new_reachable_[0] = 0; @@ -758,7 +796,7 @@ class BasicKnapsackSolver { // We only need to keep one state with the same activity. struct State { - int64_t cost = std::numeric_limits::max(); + int64_t cost = kint64max; int64_t value = 0; }; std::vector> var_activity_states_; @@ -914,19 +952,53 @@ inline bool IsStrictlyIncluded(Bitset64::ConstView in_subset, inline int64_t SafeDoubleToInt64(double value) { if (std::isnan(value)) return 0; - if (value >= static_cast(std::numeric_limits::max())) { - return std::numeric_limits::max(); + if (value >= static_cast(kint64max)) { + return kint64max; } - if (value <= static_cast(std::numeric_limits::min())) { - return std::numeric_limits::min(); + if (value <= static_cast(kint64min)) { + return kint64min; } return static_cast(value); } // Tells whether a int128 can be casted to a int64_t that can be negated. inline bool IsNegatableInt64(absl::int128 x) { - return x <= absl::int128(std::numeric_limits::max()) && - x > absl::int128(std::numeric_limits::min()); + return x <= absl::int128(kint64max) && x > absl::int128(kint64min); +} + +namespace compact_vector_vector_detail { + +// Convert int and StrongInt to normal int. +template +inline int InternalKey(K key) { + if constexpr (std::is_same_v) { + return key; + } else { + return key.value(); + } +} + +} // namespace compact_vector_vector_detail + +template +void CompactVectorVectorBuilder::Add(const K& key, const V& value) { + using compact_vector_vector_detail::InternalKey; + max_key_ = std::max(max_key_, key); + key_buffer_.push_back(key); + value_buffer_.push_back(value); +} + +template +void CompactVectorVectorBuilder::ReserveNumItems(int64_t num_items) { + key_buffer_.reserve(num_items); + value_buffer_.reserve(num_items); +} + +template +void CompactVectorVectorBuilder::Clear() { + max_key_ = 0; + key_buffer_.clear(); + value_buffer_.clear(); } template @@ -960,6 +1032,22 @@ inline void CompactVectorVector::ReplaceValuesBySmallerSet( memcpy(&buffer_[starts_[key]], values.data(), sizeof(V) * values.size()); } +template +inline void CompactVectorVector::SortAndRemoveDuplicateValues(K key) { + using compact_vector_vector_detail::InternalKey; + DCHECK_GE(key, 0); + DCHECK_LT(key, starts_.size()); + DCHECK_LT(key, sizes_.size()); + const int k = InternalKey(key); + const size_t size = static_cast(sizes_.data()[k]); + if (size > 0) { + auto span = absl::MakeSpan(&buffer_.data()[starts_.data()[k]], size); + std::sort(span.begin(), span.end()); + sizes_.data()[k] = static_cast( + std::distance(span.begin(), std::unique(span.begin(), span.end()))); + } +} + template template inline int CompactVectorVector::AddLiterals( @@ -973,18 +1061,9 @@ inline int CompactVectorVector::AddLiterals( return index; } -// We need to support both StrongType and normal int. -template -inline int CompactVectorVector::InternalKey(K key) { - if constexpr (std::is_same_v) { - return key; - } else { - return key.value(); - } -} - template inline void CompactVectorVector::Shrink(K key, int new_size) { + using compact_vector_vector_detail::InternalKey; const int k = InternalKey(key); DCHECK_LE(new_size, sizes_[k]); sizes_[k] = new_size; @@ -992,6 +1071,7 @@ inline void CompactVectorVector::Shrink(K key, int new_size) { template inline absl::Span CompactVectorVector::operator[](K key) const { + using compact_vector_vector_detail::InternalKey; DCHECK_GE(key, 0); DCHECK_LT(key, starts_.size()); DCHECK_LT(key, sizes_.size()); @@ -1003,6 +1083,7 @@ inline absl::Span CompactVectorVector::operator[](K key) const { template inline absl::Span CompactVectorVector::operator[](K key) { + using compact_vector_vector_detail::InternalKey; DCHECK_GE(key, 0); DCHECK_LT(key, starts_.size()); DCHECK_LT(key, sizes_.size()); @@ -1043,6 +1124,7 @@ template template inline void CompactVectorVector::ResetFromFlatMapping( Keys keys, Values values, int minimum_num_nodes) { + using compact_vector_vector_detail::InternalKey; // Compute maximum index. int max_key = minimum_num_nodes; for (const K key : keys) { @@ -1081,11 +1163,56 @@ inline void CompactVectorVector::ResetFromFlatMapping( starts_[0] = 0; } +template +inline void CompactVectorVector::ResetFromBuilder( + const CompactVectorVectorBuilder& builder, int minimum_num_keys) { + using compact_vector_vector_detail::InternalKey; + + if (builder.key_buffer_.empty()) { + clear(); + sizes_.assign(minimum_num_keys, 0); + starts_.assign(minimum_num_keys, 0); + return; + } + + minimum_num_keys = + std::max(minimum_num_keys, InternalKey(builder.max_key_) + 1); + + // Compute sizes_; + sizes_.assign(minimum_num_keys, 0); + for (const K key : builder.key_buffer_) { + sizes_[InternalKey(key)]++; + } + + // Compute starts_; + starts_.assign(minimum_num_keys, 0); + for (int k = 1; k < minimum_num_keys; ++k) { + starts_[k] = starts_[k - 1] + sizes_[k - 1]; + } + + // Copy data and uses starts as temporary indices. + const int num_items = builder.key_buffer_.size(); + + buffer_.resize(num_items); + const K* keys = builder.key_buffer_.data(); + const V* values = builder.value_buffer_.data(); + for (int i = 0; i < num_items; ++i) { + buffer_[starts_[InternalKey(keys[i])]++] = values[i]; + } + + // Restore starts_. + for (int k = minimum_num_keys - 1; k > 0; --k) { + starts_[k] = starts_[k - 1]; + } + starts_[0] = 0; +} + // Similar to ResetFromFlatMapping(). template template inline void CompactVectorVector::ResetFromPairs(const Collection& pairs, int minimum_num_nodes) { + using compact_vector_vector_detail::InternalKey; // Compute maximum index. int max_key = minimum_num_nodes; for (const auto& [key, _] : pairs) { @@ -1130,6 +1257,7 @@ template template void CompactVectorVector::ResetFromTransposeMap(const Container& other, int min_transpose_size) { + using compact_vector_vector_detail::InternalKey; ValueMapper mapper; if (other.size() == 0) { clear(); diff --git a/ortools/sat/util_test.cc b/ortools/sat/util_test.cc index 610efb2d0cf..39cf9925d6e 100644 --- a/ortools/sat/util_test.cc +++ b/ortools/sat/util_test.cc @@ -39,6 +39,7 @@ #include "ortools/base/gmock.h" #include "ortools/base/mathutil.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_solver.h" #include "ortools/sat/cp_model_utils.h" @@ -130,6 +131,23 @@ TEST(CompactVectorVectorTest, ShrinkValues) { EXPECT_THAT(storage[2], ElementsAre(3, 4, 5)); } +TEST(CompactVectorVectorTest, SortAndRemoveDuplicateValues) { + CompactVectorVector storage; + EXPECT_EQ(storage.size(), 0); + + storage.ResetFromFlatMapping( + std::vector({1, 1, 2, 2, 2, 1, 1, 1}), + std::vector({14, 13, 22, 21, 22, 14, 10, 14})); + storage.SortAndRemoveDuplicateValues(0); + storage.SortAndRemoveDuplicateValues(1); + storage.SortAndRemoveDuplicateValues(2); + + EXPECT_EQ(storage.size(), 3); + EXPECT_THAT(storage[0], IsEmpty()); + EXPECT_THAT(storage[1], ElementsAre(10, 13, 14)); + EXPECT_THAT(storage[2], ElementsAre(21, 22)); +} + TEST(CompactVectorVectorTest, ResetFromTranspose) { CompactVectorVector storage; EXPECT_EQ(storage.size(), 0); @@ -226,7 +244,7 @@ TEST(ModularInverseTest, AllSmallValues) { TEST(ModularInverseTest, BasicOverflowTest) { absl::BitGen random; - const int64_t max = std::numeric_limits::max(); + const int64_t max = kint64max; for (int i = 0; i < 100000; ++i) { const int64_t m = max - absl::LogUniform(random, 0, max); const int64_t x = absl::Uniform(random, 0, m); @@ -316,7 +334,7 @@ TEST(SolveDiophantineEquationOfSizeTwoTest, FewSmallValues) { TEST(SolveDiophantineEquationOfSizeTwoTest, BasicOverflowTest) { absl::BitGen random; - const int64_t max = std::numeric_limits::max(); + const int64_t max = kint64max; for (int i = 0; i < 100000; ++i) { int64_t a = max - absl::LogUniform(random, 0, max); int64_t b = max - absl::LogUniform(random, 0, max); @@ -651,8 +669,8 @@ TEST(Percentile, RandomNumbers) { TEST(SafeDoubleToInt64Test, BasicCases) { const double kInfinity = std::numeric_limits::infinity(); - const int64_t kMax = std::numeric_limits::max(); - const int64_t kMin = std::numeric_limits::min(); + const int64_t kMax = kint64max; + const int64_t kMin = kint64min; const int64_t max53 = (int64_t{1} << 53) - 1; // Arbitrary behavior for nans. @@ -783,7 +801,7 @@ BENCHMARK(BM_bounded_subset_sum) TEST(FirstFewValuesTest, Basic) { FirstFewValues<8> values; - EXPECT_EQ(values.LastValue(), std::numeric_limits::max()); + EXPECT_EQ(values.LastValue(), kint64max); values.Add(3); EXPECT_THAT(values.reachable(), ElementsAre(0, 3, 6, 9, 12, 15, 18, 21)); values.Add(5); @@ -798,7 +816,7 @@ TEST(FirstFewValuesTest, Basic) { TEST(FirstFewValuesTest, Overflow) { FirstFewValues<6> values; - const int64_t max = std::numeric_limits::max(); + const int64_t max = kint64max; const int64_t v = max / 3; values.Add(v); EXPECT_THAT(values.reachable(), ElementsAre(0, v, 2 * v, 3 * v, max, max)); diff --git a/ortools/sat/var_domination.cc b/ortools/sat/var_domination.cc index 9cab7c79f03..7905dde5e87 100644 --- a/ortools/sat/var_domination.cc +++ b/ortools/sat/var_domination.cc @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -38,6 +37,7 @@ #include "ortools/base/mathutil.h" #include "ortools/base/stl_util.h" #include "ortools/base/strong_vector.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model.pb.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/integer_base.h" @@ -686,7 +686,7 @@ void TransformLinearWithSpecialBoolean(const ConstraintProto& ct, int ref, output->push_back(ct.enforcement_literal().size()); for (const int literal : ct.enforcement_literal()) { if (literal == NegatedRef(ref)) { - output->push_back(std::numeric_limits::max()); // Sentinel + output->push_back(kint32max); // Sentinel } else { output->push_back(literal); } @@ -701,11 +701,11 @@ void TransformLinearWithSpecialBoolean(const ConstraintProto& ct, int ref, const int v = ct.linear().vars(i); const int64_t c = ct.linear().coeffs(i); if (v == ref) { - output->push_back(std::numeric_limits::max()); // Sentinel + output->push_back(kint32max); // Sentinel output->push_back(c); } else if (v == NegatedRef(ref)) { // c * v = -c * (1 - v) + c - output->push_back(std::numeric_limits::max()); // Sentinel + output->push_back(kint32max); // Sentinel output->push_back(-c); offset += c; } else { @@ -725,7 +725,7 @@ void TransformLinearWithSpecialBoolean(const ConstraintProto& ct, int ref, bool DualBoundStrengthening::Strengthen(PresolveContext* context) { SolutionCrush& crush = context->solution_crush(); num_deleted_constraints_ = 0; - const CpModelProto& cp_model = *context->working_model; + const CpModelProto& cp_model = context->WorkingModel(); const int num_vars = cp_model.variables_size(); int64_t num_fixed_vars = 0; for (int var = 0; var < num_vars; ++var) { @@ -773,19 +773,15 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // has holes. if (lb_limit > lb || ub_limit < ub) { const int64_t new_ub = - ub_limit < ub - ? context->DomainOf(var) - .IntersectionWith( - Domain(ub_limit, std::numeric_limits::max())) - .Min() - : ub; + ub_limit < ub ? context->DomainOf(var) + .IntersectionWith(Domain(ub_limit, kint64max)) + .Min() + : ub; const int64_t new_lb = - lb_limit > lb - ? context->DomainOf(var) - .IntersectionWith( - Domain(std::numeric_limits::min(), lb_limit)) - .Max() - : lb; + lb_limit > lb ? context->DomainOf(var) + .IntersectionWith(Domain(kint64min, lb_limit)) + .Max() + : lb; context->UpdateRuleStats("dual: reduced domain"); CHECK(context->IntersectDomainWith(var, Domain(new_lb, new_ub))); } @@ -824,7 +820,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { } const int ct_index = locking_ct_index_[var]; - const ConstraintProto& ct = context->working_model->constraints(ct_index); + const ConstraintProto& ct = context->Constraint(ct_index); if (ct.constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { // TODO(user): Fix variable right away rather than waiting for next call. continue; @@ -911,7 +907,6 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // order to preserve the hint feasibility. crush.SetLiteralToValueIf(ref, false, NegatedRef(enf)); context->AddImplication(NegatedRef(enf), NegatedRef(ref)); - context->UpdateNewConstraintsVariableUsage(); continue; } @@ -937,7 +932,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { } processed[PositiveRef(ref)] = true; processed[PositiveRef(var)] = true; - context->working_model->mutable_constraints(ct_index)->Clear(); + context->ClearConstraint(ct_index); context->UpdateConstraintVariableUsage(ct_index); continue; } @@ -945,7 +940,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { context->UpdateRuleStats("linear1: always true"); processed[PositiveRef(ref)] = true; processed[PositiveRef(var)] = true; - context->working_model->mutable_constraints(ct_index)->Clear(); + context->ClearConstraint(ct_index); context->UpdateConstraintVariableUsage(ct_index); continue; } @@ -992,7 +987,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { return false; } } - context->working_model->mutable_constraints(ct_index)->Clear(); + context->ClearConstraint(ct_index); context->UpdateConstraintVariableUsage(ct_index); processed[PositiveRef(ref)] = true; processed[PositiveRef(var)] = true; @@ -1013,12 +1008,11 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // (`var` in `rhs`) -- which does not apply when `ref` is true. crush.SetLiteralToValueIfLinearConstraintViolated( ref, false, {{var, 1}}, complement); - ConstraintProto* new_ct = context->working_model->add_constraints(); + ConstraintProto* new_ct = context->AddConstraint(); new_ct->add_enforcement_literal(ref); new_ct->mutable_linear()->add_vars(var); new_ct->mutable_linear()->add_coeffs(1); FillDomainInProto(complement, new_ct->mutable_linear()); - context->UpdateNewConstraintsVariableUsage(); if (rhs.IsFixed()) { context->StoreLiteralImpliesVarEqValue(NegatedRef(ref), var, value); @@ -1027,7 +1021,6 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { context->StoreLiteralImpliesVarNeValue(NegatedRef(ref), var, value); context->StoreLiteralImpliesVarEqValue(ref, var, value); } - context->UpdateNewConstraintsVariableUsage(); continue; } } @@ -1064,8 +1057,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // Already present! const auto [other_c_with_same_hash, other_ref] = it->second; CHECK_NE(other_c_with_same_hash, ct_index); - const auto& other_ct = - context->working_model->constraints(other_c_with_same_hash); + const auto& other_ct = context->Constraint(other_c_with_same_hash); TransformLinearWithSpecialBoolean(other_ct, other_ref, &other_temp_data); if (temp_data == other_temp_data) { @@ -1100,7 +1092,7 @@ bool DualBoundStrengthening::Strengthen(PresolveContext* context) { // We can delete one of the constraint since they are duplicate // now. ++num_deleted_constraints_; - context->working_model->mutable_constraints(ct_index)->Clear(); + context->ClearConstraint(ct_index); context->UpdateConstraintVariableUsage(ct_index); continue; } @@ -1187,7 +1179,7 @@ void ScanModelForDominanceDetection(PresolveContext& context, VarDomination* var_domination) { if (context.ModelIsUnsat()) return; - const CpModelProto& cp_model = *context.working_model; + const CpModelProto& cp_model = context.WorkingModel(); const int num_vars = cp_model.variables().size(); var_domination->Reset(num_vars); @@ -1393,7 +1385,7 @@ void ScanModelForDualBoundStrengthening( const PresolveContext& context, DualBoundStrengthening* dual_bound_strengthening) { if (context.ModelIsUnsat()) return; - const CpModelProto& cp_model = *context.working_model; + const CpModelProto& cp_model = context.WorkingModel(); const int num_vars = cp_model.variables().size(); dual_bound_strengthening->Reset(num_vars); @@ -1538,7 +1530,7 @@ bool ProcessAtMostOne( bool ExploitDominanceRelations(const VarDomination& var_domination, PresolveContext* context) { - const CpModelProto& cp_model = *context->working_model; + const CpModelProto& cp_model = context->WorkingModel(); const int num_vars = cp_model.variables_size(); // Abort early if there is nothing to do. @@ -1562,7 +1554,7 @@ bool ExploitDominanceRelations(const VarDomination& var_domination, util_intops::StrongVector can_freely_decrease_count( num_vars * 2, 0); util_intops::StrongVector can_freely_decrease_until( - num_vars * 2, std::numeric_limits::min()); + num_vars * 2, kint64min); // Temporary data that we fill/clear for each linear constraint. util_intops::StrongVector var_lb_to_ub_diff( @@ -1954,7 +1946,6 @@ bool ExploitDominanceRelations(const VarDomination& var_domination, // call below fixes it by negating both values. Otherwise it does // nothing and thus preserves its feasibility. crush.UpdateLiteralsWithDominance(ref, dom_ref); - context->UpdateNewConstraintsVariableUsage(); implications.insert({ref, dom_ref}); implications.insert({NegatedRef(dom_ref), NegatedRef(ref)}); diff --git a/ortools/sat/var_domination_test.cc b/ortools/sat/var_domination_test.cc index 543d9729b01..d904dbdf0c0 100644 --- a/ortools/sat/var_domination_test.cc +++ b/ortools/sat/var_domination_test.cc @@ -122,7 +122,6 @@ TEST(VarDominationTest, ExploitDominanceRelation) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -173,7 +172,6 @@ TEST(VarDominationTest, ExploitDominanceRelationWithHoles) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -219,8 +217,7 @@ TEST(VarDominationTest, ExploitDominanceOfImplicant) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -268,8 +265,7 @@ TEST(VarDominationTest, ExploitDominanceOfNegatedImplicand) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -314,8 +310,7 @@ TEST(VarDominationTest, ExploitDominanceInExactlyOne) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -368,8 +363,7 @@ TEST(VarDominationTest, ExploitDominanceWithIntegerVariables) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -414,19 +408,17 @@ TEST(VarDominationTest, ExploitRemainingDominance) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); // Check that an implication between X and Y was added, and that the hint was // updated in consequence. - EXPECT_EQ(context.working_model->constraints_size(), 2); + EXPECT_EQ(context.NumConstraints(), 2); const ConstraintProto expected_constraint_proto = ParseTestProto(R"pb(enforcement_literal: -1 bool_and { literals: -2 })pb"); - EXPECT_THAT(context.working_model->constraints(1), - EqualsProto(expected_constraint_proto)); + EXPECT_THAT(context.Constraint(1), EqualsProto(expected_constraint_proto)); EXPECT_EQ(context.DomainOf(0).ToString(), "[0,1]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]"); EXPECT_EQ(context.solution_crush().GetVarValues()[0], 1); @@ -481,8 +473,7 @@ TEST(VarDominationTest, ExploitRemainingDominanceWithIntegerVariables) { PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); context.ReadObjectiveFromProto(); - context.UpdateNewConstraintsVariableUsage(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); ScanModelForDominanceDetection(context, &var_dom); EXPECT_TRUE(ExploitDominanceRelations(var_dom, &context)); @@ -827,7 +818,7 @@ TEST(DualBoundReductionTest, FixVariableToDomainBound) { Model model; PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.ReadObjectiveFromProto(); ScanModelForDualBoundStrengthening(context, &dual_bound_strengthening); @@ -865,7 +856,7 @@ TEST(DualBoundReductionTest, BasicTest) { Model model; PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.ReadObjectiveFromProto(); ScanModelForDualBoundStrengthening(context, &dual_bound_strengthening); @@ -929,7 +920,7 @@ TEST(DualBoundReductionTest, Choices) { Model model; PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.ReadObjectiveFromProto(); ScanModelForDualBoundStrengthening(context, &dual_bound_strengthening); @@ -966,24 +957,22 @@ TEST(DualBoundReductionTest, AddImplication) { Model model; PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.ReadObjectiveFromProto(); ScanModelForDualBoundStrengthening(context, &dual_bound_strengthening); EXPECT_TRUE(dual_bound_strengthening.Strengthen(&context)); // not(a) => not(b) and not(a) => not(c) should be added. - ASSERT_EQ(context.working_model->constraints_size(), 3); + ASSERT_EQ(context.NumConstraints(), 3); const ConstraintProto expected_constraint_proto1 = ParseTestProto(R"pb(enforcement_literal: -1 bool_and { literals: -2 })pb"); - EXPECT_THAT(context.working_model->constraints(1), - EqualsProto(expected_constraint_proto1)); + EXPECT_THAT(context.Constraint(1), EqualsProto(expected_constraint_proto1)); const ConstraintProto expected_constraint_proto2 = ParseTestProto(R"pb(enforcement_literal: -1 bool_and { literals: -3 })pb"); - EXPECT_THAT(context.working_model->constraints(2), - EqualsProto(expected_constraint_proto2)); + EXPECT_THAT(context.Constraint(2), EqualsProto(expected_constraint_proto2)); EXPECT_EQ(context.DomainOf(0).ToString(), "[0]"); EXPECT_EQ(context.DomainOf(1).ToString(), "[0,1]"); EXPECT_EQ(context.DomainOf(2).ToString(), "[0,1]"); @@ -1024,7 +1013,7 @@ TEST(DualBoundReductionTest, EquivalenceDetection) { Model model; PresolveContext context(&model, &model_proto, nullptr); context.InitializeNewDomains(); - context.LoadSolutionHint(); + context.LoadAndClampSolutionHint(); context.ReadObjectiveFromProto(); ScanModelForDualBoundStrengthening(context, &dual_bound_strengthening); diff --git a/ortools/sat/variable_expand.cc b/ortools/sat/variable_expand.cc index 4c268e30347..eafd1123a19 100644 --- a/ortools/sat/variable_expand.cc +++ b/ortools/sat/variable_expand.cc @@ -28,6 +28,7 @@ #include "absl/strings/str_cat.h" #include "ortools/base/log_severity.h" #include "ortools/base/stl_util.h" +#include "ortools/base/types.h" #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/presolve_context.h" #include "ortools/sat/solution_crush.h" @@ -76,7 +77,7 @@ enum class EncodingLinear1Status { struct EncodingLinear1 { EncodingLinear1Type type; - int64_t value = std::numeric_limits::min(); + int64_t value = kint64min; Domain rhs; // Only used for kVarInDomain. int enforcement_literal; int constraint_index; @@ -305,8 +306,8 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( const int64_t max_le_value = encoded_le_literal_.rbegin()->first; const int64_t max_ge_value = var_domain_.ValueAtOrAfter(max_le_value + 1); ConstraintProto* not_le = nullptr; - ConstraintProto* not_ge = context_->working_model->add_constraints(); - ConstraintProto* le = context_->working_model->add_constraints(); + ConstraintProto* not_ge = context_->AddConstraint(); + ConstraintProto* le = context_->AddConstraint(); ConstraintProto* ge = nullptr; for (const auto [value, eq_literal] : values.encoding()) { @@ -327,7 +328,7 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( DCHECK(le != nullptr); le->add_enforcement_literal(le_literal); if (value < max_le_value) { - le = context_->working_model->add_constraints(); + le = context_->AddConstraint(); le->mutable_bool_or()->add_literals(le_literal); } else { le = nullptr; @@ -354,7 +355,7 @@ void OrderEncoding::CreateAllOrderEncodingLiterals( DCHECK(not_ge != nullptr); not_ge->add_enforcement_literal(ge_literal); if (value != max_ge_value) { - not_ge = context_->working_model->add_constraints(); + not_ge = context_->AddConstraint(); not_ge->mutable_bool_and()->add_literals(ge_literal); } else { not_ge = nullptr; @@ -479,7 +480,7 @@ bool ProcessEncodingConstraints( // Sort the constraint indices to make the encoding deterministic. absl::c_sort(constraint_indices); for (const int c : constraint_indices) { - const ConstraintProto& ct = context->working_model->constraints(c); + const ConstraintProto& ct = context->Constraint(c); DCHECK_EQ(ct.constraint_case(), ConstraintProto::kLinear); DCHECK_EQ(ct.linear().vars().size(), 1); DCHECK(RefIsPositive(ct.linear().vars(0))); @@ -760,8 +761,7 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, // Note, the use of exactly_one here is correct because this is a partition, // and the two equations complement each other. for (const EncodingLinear1& info_in : lin_domain) { - BoolArgumentProto* exo = - context->working_model->add_constraints()->mutable_exactly_one(); + BoolArgumentProto* exo = context->AddConstraint()->mutable_exactly_one(); exo->add_literals(NegatedRef(info_in.enforcement_literal)); for (const int64_t v : info_in.rhs.Values()) { exo->add_literals(values.literal(v)); @@ -788,7 +788,7 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, const int e_j = info_j.enforcement_literal; if (e_i == NegatedRef(e_j)) return; BoolArgumentProto* incompatible = - context->working_model->add_constraints()->mutable_bool_or(); + context->AddConstraint()->mutable_bool_or(); incompatible->add_literals(NegatedRef(e_i)); incompatible->add_literals(NegatedRef(e_j)); context->UpdateRuleStats( @@ -824,7 +824,6 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, } } } - context->UpdateNewConstraintsVariableUsage(); // Update the objective if needed. Note that this operation can fail if // the new expression result in potential overflow. @@ -845,7 +844,7 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, int64_t accumulated = std::abs(base_value); for (const int64_t value : values.encoded_values()) { accumulated = CapAdd(accumulated, std::abs(CapSub(value, base_value))); - if (accumulated == std::numeric_limits::max()) { + if (accumulated == kint64max) { VLOG(2) << "Abort - overflow when converting linear1 to clauses"; context->UpdateRuleStats( "TODO variables: overflow when converting linear1 to clauses"); @@ -916,18 +915,17 @@ void TryToReplaceVariableByItsEncoding(int var, PresolveContext* context, } absl::c_sort(to_clear); for (const int c : to_clear) { - context->working_model->mutable_constraints(c)->Clear(); + context->ClearConstraint(c); context->UpdateConstraintVariableUsage(c); } } // This must be done after we removed all the constraint containing var. - ConstraintProto* exo = context->working_model->add_constraints(); + ConstraintProto* exo = context->AddConstraint(); BoolArgumentProto* arg = exo->mutable_exactly_one(); for (const auto& [value, literal] : values.encoding()) { arg->add_literals(literal); } - context->UpdateNewConstraintsVariableUsage(); if (context->ModelIsUnsat()) return; // To simplify the postsolve, we output a single constraint to infer X from diff --git a/ortools/sat/work_assignment.cc b/ortools/sat/work_assignment.cc index e914eb67683..2059e08f056 100644 --- a/ortools/sat/work_assignment.cc +++ b/ortools/sat/work_assignment.cc @@ -18,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -33,6 +32,7 @@ #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_utils.h" @@ -272,11 +272,11 @@ SharedTreeManager::SharedTreeManager(Model* model) model->GetOrCreate())), num_splits_wanted_( num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1), - max_nodes_( - params_.shared_tree_max_nodes_per_worker() >= - std::numeric_limits::max() / std::max(num_workers_, 1) - ? std::numeric_limits::max() - : num_workers_ * params_.shared_tree_max_nodes_per_worker()) { + max_nodes_(params_.shared_tree_max_nodes_per_worker() >= + kint32max / std::max(num_workers_, 1) + ? kint32max + : num_workers_ * + params_.shared_tree_max_nodes_per_worker()) { // Create the root node with a fake decision. nodes_.push_back( {.decision = ProtoLiteral(), @@ -1055,7 +1055,7 @@ SharedTreeWorker::SharedTreeWorker(Model* model) integer_trail_(model->GetOrCreate()), encoder_(model->GetOrCreate()), objective_(model->Get()), - random_(model->GetOrCreate()), + random_(*model->GetOrCreate()), helper_(model->GetOrCreate()), heuristics_(model->GetOrCreate()), decision_policy_(model->GetOrCreate()), diff --git a/ortools/sat/work_assignment.h b/ortools/sat/work_assignment.h index d5d8d8d1a8e..eea37322d1b 100644 --- a/ortools/sat/work_assignment.h +++ b/ortools/sat/work_assignment.h @@ -35,6 +35,7 @@ #include "absl/log/check.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "ortools/base/types.h" #include "ortools/sat/clause.h" #include "ortools/sat/cp_model_mapping.h" #include "ortools/sat/cp_model_utils.h" @@ -91,7 +92,7 @@ class ProtoLiteral { static std::optional EncodeInteger(IntegerLiteral, CpModelMapping*); - int proto_var_ = std::numeric_limits::max(); + int proto_var_ = kint32max; IntegerValue lb_ = kMaxIntegerValue; }; @@ -444,7 +445,7 @@ class SharedTreeWorker { IntegerTrail* integer_trail_; IntegerEncoder* encoder_; const ObjectiveDefinition* objective_; - ModelRandomGenerator* random_; + absl::BitGenRef random_; IntegerSearchHelper* helper_; SearchHeuristics* heuristics_; SatDecisionPolicy* decision_policy_; diff --git a/ortools/sat/zero_half_cuts.cc b/ortools/sat/zero_half_cuts.cc index b649f94d489..5385e2ee882 100644 --- a/ortools/sat/zero_half_cuts.cc +++ b/ortools/sat/zero_half_cuts.cc @@ -18,6 +18,7 @@ #include #include "absl/log/check.h" +#include "absl/random/bit_gen_ref.h" #include "absl/types/span.h" #include "ortools/lp_data/lp_types.h" #include "ortools/sat/integer_base.h" @@ -214,7 +215,7 @@ void ZeroHalfCutHelper::EliminateVarUsingRow(int eliminated_col, } std::vector>> -ZeroHalfCutHelper::InterestingCandidates(ModelRandomGenerator* random) { +ZeroHalfCutHelper::InterestingCandidates(absl::BitGenRef random) { std::vector>> result; // Remove singleton column from the picture. @@ -238,7 +239,7 @@ ZeroHalfCutHelper::InterestingCandidates(ModelRandomGenerator* random) { // Process rows by increasing size, but randomize if same size. std::vector to_process; for (int row = 0; row < rows_.size(); ++row) to_process.push_back(row); - std::shuffle(to_process.begin(), to_process.end(), *random); + std::shuffle(to_process.begin(), to_process.end(), random); std::stable_sort(to_process.begin(), to_process.end(), [this](int a, int b) { return rows_[a].cols.size() < rows_[b].cols.size(); }); diff --git a/ortools/sat/zero_half_cuts.h b/ortools/sat/zero_half_cuts.h index f4882375a02..60c99a668cd 100644 --- a/ortools/sat/zero_half_cuts.h +++ b/ortools/sat/zero_half_cuts.h @@ -52,7 +52,7 @@ class ZeroHalfCutHelper { absl::Span coeffs, IntegerValue lb, IntegerValue ub); std::vector>> - InterestingCandidates(ModelRandomGenerator* random); + InterestingCandidates(absl::BitGenRef random); // Visible for testing. void Reset(int size); diff --git a/ortools/util/bitset.h b/ortools/util/bitset.h index 8f5ddcc9c74..8a3b308f3a3 100644 --- a/ortools/util/bitset.h +++ b/ortools/util/bitset.h @@ -542,7 +542,7 @@ class Bitset64 { // Sets the bit at position i to 1. void Set(IndexType i) { DCHECK_GE(Value(i), 0); - DCHECK_LT(Value(i), size_); + DCHECK_LT(Value(i), Value(size_)); // The c++ hardening is costly here, so we disable it. data_.data()[BitOffset64(Value(i))] |= OneBit64(BitPos64(Value(i))); } diff --git a/ortools/util/status_macros.h b/ortools/util/status_macros.h index a5b06f3e22e..554fa118ad8 100644 --- a/ortools/util/status_macros.h +++ b/ortools/util/status_macros.h @@ -29,10 +29,14 @@ // WARNING: OR_ASSIGN_OR_RETURN3 expands into multiple statements; it cannot be // used in a single statement (e.g. as the body of an if statement without {})! // -#define OR_ASSIGN_OR_RETURN3(lhs, rexpr, error_expression) \ - OR_ASSIGN_OR_RETURN3_IMPL_( \ - STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __COUNTER__), lhs, rexpr, \ - error_expression) +#define OR_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y +#define OR_STATUS_MACROS_IMPL_CONCAT_(x, y) \ + OR_STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) + +#define OR_ASSIGN_OR_RETURN3(lhs, rexpr, error_expression) \ + OR_ASSIGN_OR_RETURN3_IMPL_( \ + OR_STATUS_MACROS_IMPL_CONCAT_(_status_or_value, __COUNTER__), lhs, \ + rexpr, error_expression) #define OR_ASSIGN_OR_RETURN3_IMPL_(statusor, lhs, rexpr, error_expression) \ auto statusor = (rexpr); \