From 90a75866fbe1eb66813444ea7c7d0d9a26046fb4 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Tue, 7 Feb 2023 03:17:52 -0800 Subject: [PATCH] elaborating on local-search rephase strategy --- src/sat/sat_ddfw.cpp | 52 +++++++++++++------------- src/sat/sat_ddfw.h | 2 + src/sat/sat_parallel.cpp | 20 ++++++---- src/sat/sat_parallel.h | 4 +- src/sat/sat_params.pyg | 2 +- src/sat/sat_prob.h | 10 +++-- src/sat/sat_solver.cpp | 80 +++++++++++++++++++++++++++++++++------- src/sat/sat_solver.h | 15 ++++++++ src/sat/sat_types.h | 2 +- 9 files changed, 131 insertions(+), 56 deletions(-) diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index f1493232c80..418070c64c9 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -57,11 +57,11 @@ namespace sat { double sec = m_stopwatch.get_current_seconds(); double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec); if (m_last_flips == 0) { - IF_VERBOSE(0, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts"; + IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts"; if (m_par) verbose_stream() << " :par"; verbose_stream() << ")\n"); } - IF_VERBOSE(0, verbose_stream() << "(sat.ddfw " + IF_VERBOSE(1, verbose_stream() << "(sat.ddfw " << std::setw(07) << m_min_sz << std::setw(07) << m_models.size() << std::setw(10) << kflips_per_sec @@ -106,7 +106,6 @@ namespace sat { if (r > 0) { lim_pos -= score(r); if (lim_pos <= 0) { - if (m_par) update_reward_avg(v); return v; } } @@ -139,9 +138,8 @@ namespace sat { } void ddfw::add(solver const& s) { - for (auto& ci : m_clauses) { + for (auto& ci : m_clauses) m_alloc.del_clause(ci.m_clause); - } m_clauses.reset(); m_use_list.reset(); m_num_non_binary_clauses = 0; @@ -281,6 +279,7 @@ namespace sat { ci.add(nlit); } value(v) = !value(v); + update_reward_avg(v); } bool ddfw::should_reinit_weights() { @@ -379,36 +378,35 @@ namespace sat { return m_par != nullptr && m_flips >= m_parsync_next; } + void ddfw::save_priorities() { + m_probs.reset(); + for (unsigned v = 0; v < num_vars(); ++v) + m_probs.push_back(-m_vars[v].m_reward_avg); + } + void ddfw::do_parallel_sync() { - if (m_par->from_solver(*this)) { - // Sum exp(xi) / exp(a) = Sum exp(xi - a) - double max_avg = 0; - for (unsigned v = 0; v < num_vars(); ++v) { - max_avg = std::max(max_avg, (double)m_vars[v].m_reward_avg); - } - double sum = 0; - for (unsigned v = 0; v < num_vars(); ++v) { - sum += exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg)); - } - if (sum == 0) { - sum = 0.01; - } - m_probs.reset(); - for (unsigned v = 0; v < num_vars(); ++v) { - m_probs.push_back(exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg)) / sum); - } + if (m_par->from_solver(*this)) m_par->to_solver(*this); - } + ++m_parsync_count; m_parsync_next *= 3; m_parsync_next /= 2; } + void ddfw::save_model() { + m_model.reserve(num_vars()); + for (unsigned i = 0; i < num_vars(); ++i) + m_model[i] = to_lbool(value(i)); + save_priorities(); + } + + void ddfw::save_best_values() { - if (m_unsat.empty()) { - m_model.reserve(num_vars()); - for (unsigned i = 0; i < num_vars(); ++i) - m_model[i] = to_lbool(value(i)); + if (m_unsat.empty()) + save_model(); + else if (m_unsat.size() < m_min_sz) { + if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11) + save_model(); } if (m_unsat.size() < m_min_sz) { m_models.reset(); diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index ed9936f0ac9..d5e7df77374 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -164,6 +164,8 @@ namespace sat { bool_var pick_var(); void flip(bool_var v); void save_best_values(); + void save_model(); + void save_priorities(); // shift activity void shift_weights(); diff --git a/src/sat/sat_parallel.cpp b/src/sat/sat_parallel.cpp index 2f7a195588b..3e493168a0c 100644 --- a/src/sat/sat_parallel.cpp +++ b/src/sat/sat_parallel.cpp @@ -214,14 +214,17 @@ namespace sat { } - bool parallel::_to_solver(solver& s) { - if (m_priorities.empty()) { - return false; - } + void parallel::_to_solver(solver& s) { + return; +#if 0 + if (m_priorities.empty()) + return; + for (bool_var v = 0; v < m_priorities.size(); ++v) { s.update_activity(v, m_priorities[v]); } - return true; + s.m_activity_inc = 128; +#endif } void parallel::from_solver(solver& s) { @@ -229,16 +232,19 @@ namespace sat { _from_solver(s); } - bool parallel::to_solver(solver& s) { + void parallel::to_solver(solver& s) { lock_guard lock(m_mux); - return _to_solver(s); + _to_solver(s); } void parallel::_to_solver(i_local_search& s) { + return; +#if 0 m_priorities.reset(); for (bool_var v = 0; m_solver_copy && v < m_solver_copy->num_vars(); ++v) { m_priorities.push_back(s.get_priority(v)); } +#endif } bool parallel::_from_solver(i_local_search& s) { diff --git a/src/sat/sat_parallel.h b/src/sat/sat_parallel.h index 68266760ba0..65ae091835e 100644 --- a/src/sat/sat_parallel.h +++ b/src/sat/sat_parallel.h @@ -51,7 +51,7 @@ namespace sat { bool enable_add(clause const& c) const; void _get_clauses(solver& s); void _from_solver(solver& s); - bool _to_solver(solver& s); + void _to_solver(solver& s); bool _from_solver(i_local_search& s); void _to_solver(i_local_search& s); @@ -102,7 +102,7 @@ namespace sat { // exchange from solver state to local search and back. void from_solver(solver& s); - bool to_solver(solver& s); + void to_solver(solver& s); bool from_solver(i_local_search& s); void to_solver(i_local_search& s); diff --git a/src/sat/sat_params.pyg b/src/sat/sat_params.pyg index f9d7c643ac6..6aedf1c89e9 100644 --- a/src/sat/sat_params.pyg +++ b/src/sat/sat_params.pyg @@ -2,7 +2,7 @@ def_module_params('sat', export=True, description='propositional SAT solver', params=(max_memory_param(), - ('phase', SYMBOL, 'caching', 'phase selection strategy: always_false, always_true, basic_caching, random, caching'), + ('phase', SYMBOL, 'caching', 'phase selection strategy: always_false, always_true, basic_caching, random, caching, local_search'), ('phase.sticky', BOOL, True, 'use sticky phase caching'), ('search.unsat.conflicts', UINT, 400, 'period for solving for unsat (in number of conflicts)'), ('search.sat.conflicts', UINT, 400, 'period for solving for sat (in number of conflicts)'), diff --git a/src/sat/sat_prob.h b/src/sat/sat_prob.h index 305e76b8bcc..f05365e3902 100644 --- a/src/sat/sat_prob.h +++ b/src/sat/sat_prob.h @@ -58,7 +58,7 @@ namespace sat { clause_vector m_clause_db; svector m_clauses; bool_vector m_values, m_best_values; - unsigned m_best_min_unsat{ 0 }; + unsigned m_best_min_unsat = 0; vector m_use_list; unsigned_vector m_flat_use_list; unsigned_vector m_use_list_index; @@ -67,9 +67,9 @@ namespace sat { indexed_uint_set m_unsat; random_gen m_rand; unsigned_vector m_breaks; - uint64_t m_flips{ 0 }; - uint64_t m_next_restart{ 0 }; - unsigned m_restart_count{ 0 }; + uint64_t m_flips = 0; + uint64_t m_next_restart = 0; + unsigned m_restart_count = 0; stopwatch m_stopwatch; model m_model; @@ -139,6 +139,8 @@ namespace sat { void add(solver const& s) override; model const& get_model() const override { return m_model; } + + double get_priority(bool_var v) const { return 0; } std::ostream& display(std::ostream& out) const; diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 29710ed29ed..373885f341b 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -40,6 +40,26 @@ Revision History: namespace sat { + /** + * Special cases of kissat style general backoff calculation. + * The version here calculates + * limit := value*log(C)^2*n*log(n) + * (effort calculation in kissat is based on ticks not clauses) + * + * respectively + * limit := conflicts + value*log(C)^2*n*log(n) + */ + void backoff::delta_effort(solver& s) { + count++; + unsigned d = value * count * log2(count + 1); + unsigned cl = log2(s.num_clauses() + 2); + limit = cl * cl * d; + } + + void backoff::delta_conflicts(solver& s) { + delta_effort(s); + limit += s.m_conflicts_since_init; + } solver::solver(params_ref const & p, reslimit& l): solver_core(l), @@ -1349,16 +1369,43 @@ namespace sat { m_local_search->updt_params(m_params); m_local_search->set_seed(m_rand()); scoped_rl.push_child(&(m_local_search->rlimit())); - m_local_search->rlimit().push(500000); + + m_backoffs.m_local_search.delta_effort(*this); + m_local_search->rlimit().push(m_backoffs.m_local_search.limit); + m_local_search->reinit(*this); lbool r = m_local_search->check(_lits.size(), _lits.data(), nullptr); - for (unsigned i = 0; i < m_phase.size(); ++i) - m_best_phase[i] = m_local_search->get_value(i); - if (r == l_true) { - m_conflicts_since_restart = 0; - m_conflicts_since_gc = 0; - m_next_simplify = std::max(m_next_simplify, m_conflicts_since_init + 1); + auto const& mdl = m_local_search->get_model(); + if (mdl.size() == m_best_phase.size()) { + for (unsigned i = 0; i < m_best_phase.size(); ++i) + m_best_phase[i] = l_true == mdl[i]; + + if (r == l_true) { + m_conflicts_since_restart = 0; + m_conflicts_since_gc = 0; + m_next_simplify = std::max(m_next_simplify, m_conflicts_since_init + 1); + } do_restart(true); +#if 0 + // move higher priority variables to front + // eg., move the first 10% variables to front + svector> priorities(mdl.size()); + for (unsigned i = 0; i < mdl.size(); ++i) + priorities[i] = { m_local_search->get_priority(i), i }; + std::sort(priorities.begin(), priorities.end(), [](auto& x, auto& y) { return x.first > y.first; }); + for (unsigned i = priorities.size() / 10; i-- > 0; ) + move_to_front(priorities[i].second); +#endif + + + if (l_true == r) { + for (clause const* cp : m_clauses) { + bool is_true = any_of(*cp, [&](auto lit) { return lit.sign() != m_best_phase[lit.var()]; }); + if (!is_true) { + verbose_stream() << "clause is false " << *cp << "\n"; + } + } + } } } @@ -1693,7 +1740,7 @@ namespace sat { if (!is_pos) next_lit.neg(); - + TRACE("sat_decide", tout << scope_lvl() << ": next-case-split: " << next_lit << "\n";); assign_scoped(next_lit); return true; @@ -1913,6 +1960,7 @@ namespace sat { m_rephase_lim = 0; m_rephase_inc = 0; m_reorder_lim = m_config.m_reorder_base; + m_backoffs.m_local_search.value = 500; m_reorder_inc = 0; m_conflicts_since_restart = 0; m_force_conflict_analysis = false; @@ -1928,6 +1976,7 @@ namespace sat { m_next_simplify = m_config.m_simplify_delay; m_min_d_tk = 1.0; m_search_lvl = 0; + if (m_learned.size() <= 2*m_clauses.size()) m_conflicts_since_gc = 0; m_restart_next_out = 0; @@ -2030,9 +2079,7 @@ namespace sat { if (m_par) { m_par->from_solver(*this); - if (m_par->to_solver(*this)) { - m_activity_inc = 128; - } + m_par->to_solver(*this); } if (m_config.m_binspr && !inconsistent()) { @@ -2944,14 +2991,19 @@ namespace sat { case PS_SAT_CACHING: if (m_search_state == s_sat) for (unsigned i = 0; i < m_phase.size(); ++i) - m_phase[i] = m_best_phase[i]; + m_phase[i] = m_best_phase[i]; break; case PS_RANDOM: for (auto& p : m_phase) p = (m_rand() % 2) == 0; break; case PS_LOCAL_SEARCH: - if (m_search_state == s_sat) - bounded_local_search(); + if (m_search_state == s_sat) { + if (m_rand() % 2 == 0) + bounded_local_search(); + for (unsigned i = 0; i < m_phase.size(); ++i) + m_phase[i] = m_best_phase[i]; + } + break; default: UNREACHABLE(); diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 524f6b06de3..ca738ce9b14 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -87,10 +87,23 @@ namespace sat { struct no_drat_params : public params_ref { no_drat_params() { set_bool("drat.disable", true); } }; + + struct backoff { + unsigned value = 1; + unsigned lo = 0; + unsigned hi = 0; + unsigned limit = 0; + unsigned count = 0; + void delta_effort(solver& s); + void delta_conflicts(solver& s); + }; class solver : public solver_core { public: struct abort_solver {}; + struct backoffs { + backoff m_local_search; + }; protected: enum search_state { s_sat, s_unsat }; @@ -159,6 +172,7 @@ namespace sat { unsigned m_search_next_toggle; unsigned m_phase_counter; unsigned m_best_phase_size; + backoffs m_backoffs; unsigned m_rephase_lim; unsigned m_rephase_inc; unsigned m_reorder_lim; @@ -237,6 +251,7 @@ namespace sat { friend class lut_finder; friend class npn3_finder; friend class proof_trim; + friend struct backoff; public: solver(params_ref const & p, reslimit& l); ~solver() override; diff --git a/src/sat/sat_types.h b/src/sat/sat_types.h index c92a8bbeb4f..3026b3c5e00 100644 --- a/src/sat/sat_types.h +++ b/src/sat/sat_types.h @@ -90,7 +90,7 @@ namespace sat { virtual reslimit& rlimit() = 0; virtual model const& get_model() const = 0; virtual void collect_statistics(statistics& st) const = 0; - virtual double get_priority(bool_var v) const { return 0; } + virtual double get_priority(bool_var v) const = 0; virtual bool get_value(bool_var v) const { return true; } };