Skip to content

Commit

Permalink
elaborating on local-search rephase strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Feb 7, 2023
1 parent f3ae769 commit 90a7586
Show file tree
Hide file tree
Showing 9 changed files with 131 additions and 56 deletions.
52 changes: 25 additions & 27 deletions src/sat/sat_ddfw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ namespace sat {
double sec = m_stopwatch.get_current_seconds();
double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec);
if (m_last_flips == 0) {
IF_VERBOSE(0, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts";
IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts";
if (m_par) verbose_stream() << " :par";
verbose_stream() << ")\n");
}
IF_VERBOSE(0, verbose_stream() << "(sat.ddfw "
IF_VERBOSE(1, verbose_stream() << "(sat.ddfw "
<< std::setw(07) << m_min_sz
<< std::setw(07) << m_models.size()
<< std::setw(10) << kflips_per_sec
Expand Down Expand Up @@ -106,7 +106,6 @@ namespace sat {
if (r > 0) {
lim_pos -= score(r);
if (lim_pos <= 0) {
if (m_par) update_reward_avg(v);
return v;
}
}
Expand Down Expand Up @@ -139,9 +138,8 @@ namespace sat {
}

void ddfw::add(solver const& s) {
for (auto& ci : m_clauses) {
for (auto& ci : m_clauses)
m_alloc.del_clause(ci.m_clause);
}
m_clauses.reset();
m_use_list.reset();
m_num_non_binary_clauses = 0;
Expand Down Expand Up @@ -281,6 +279,7 @@ namespace sat {
ci.add(nlit);
}
value(v) = !value(v);
update_reward_avg(v);
}

bool ddfw::should_reinit_weights() {
Expand Down Expand Up @@ -379,36 +378,35 @@ namespace sat {
return m_par != nullptr && m_flips >= m_parsync_next;
}

void ddfw::save_priorities() {
m_probs.reset();
for (unsigned v = 0; v < num_vars(); ++v)
m_probs.push_back(-m_vars[v].m_reward_avg);
}

void ddfw::do_parallel_sync() {
if (m_par->from_solver(*this)) {
// Sum exp(xi) / exp(a) = Sum exp(xi - a)
double max_avg = 0;
for (unsigned v = 0; v < num_vars(); ++v) {
max_avg = std::max(max_avg, (double)m_vars[v].m_reward_avg);
}
double sum = 0;
for (unsigned v = 0; v < num_vars(); ++v) {
sum += exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg));
}
if (sum == 0) {
sum = 0.01;
}
m_probs.reset();
for (unsigned v = 0; v < num_vars(); ++v) {
m_probs.push_back(exp(m_config.m_itau * (m_vars[v].m_reward_avg - max_avg)) / sum);
}
if (m_par->from_solver(*this))
m_par->to_solver(*this);
}

++m_parsync_count;
m_parsync_next *= 3;
m_parsync_next /= 2;
}

void ddfw::save_model() {
m_model.reserve(num_vars());
for (unsigned i = 0; i < num_vars(); ++i)
m_model[i] = to_lbool(value(i));
save_priorities();
}


void ddfw::save_best_values() {
if (m_unsat.empty()) {
m_model.reserve(num_vars());
for (unsigned i = 0; i < num_vars(); ++i)
m_model[i] = to_lbool(value(i));
if (m_unsat.empty())
save_model();
else if (m_unsat.size() < m_min_sz) {
if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11)
save_model();
}
if (m_unsat.size() < m_min_sz) {
m_models.reset();
Expand Down
2 changes: 2 additions & 0 deletions src/sat/sat_ddfw.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ namespace sat {
bool_var pick_var();
void flip(bool_var v);
void save_best_values();
void save_model();
void save_priorities();

// shift activity
void shift_weights();
Expand Down
20 changes: 13 additions & 7 deletions src/sat/sat_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,31 +214,37 @@ namespace sat {
}


bool parallel::_to_solver(solver& s) {
if (m_priorities.empty()) {
return false;
}
void parallel::_to_solver(solver& s) {
return;
#if 0
if (m_priorities.empty())
return;

for (bool_var v = 0; v < m_priorities.size(); ++v) {
s.update_activity(v, m_priorities[v]);
}
return true;
s.m_activity_inc = 128;
#endif
}

void parallel::from_solver(solver& s) {
lock_guard lock(m_mux);
_from_solver(s);
}

bool parallel::to_solver(solver& s) {
void parallel::to_solver(solver& s) {
lock_guard lock(m_mux);
return _to_solver(s);
_to_solver(s);
}

void parallel::_to_solver(i_local_search& s) {
return;
#if 0
m_priorities.reset();
for (bool_var v = 0; m_solver_copy && v < m_solver_copy->num_vars(); ++v) {
m_priorities.push_back(s.get_priority(v));
}
#endif
}

bool parallel::_from_solver(i_local_search& s) {
Expand Down
4 changes: 2 additions & 2 deletions src/sat/sat_parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace sat {
bool enable_add(clause const& c) const;
void _get_clauses(solver& s);
void _from_solver(solver& s);
bool _to_solver(solver& s);
void _to_solver(solver& s);
bool _from_solver(i_local_search& s);
void _to_solver(i_local_search& s);

Expand Down Expand Up @@ -102,7 +102,7 @@ namespace sat {

// exchange from solver state to local search and back.
void from_solver(solver& s);
bool to_solver(solver& s);
void to_solver(solver& s);

bool from_solver(i_local_search& s);
void to_solver(i_local_search& s);
Expand Down
2 changes: 1 addition & 1 deletion src/sat/sat_params.pyg
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def_module_params('sat',
export=True,
description='propositional SAT solver',
params=(max_memory_param(),
('phase', SYMBOL, 'caching', 'phase selection strategy: always_false, always_true, basic_caching, random, caching'),
('phase', SYMBOL, 'caching', 'phase selection strategy: always_false, always_true, basic_caching, random, caching, local_search'),
('phase.sticky', BOOL, True, 'use sticky phase caching'),
('search.unsat.conflicts', UINT, 400, 'period for solving for unsat (in number of conflicts)'),
('search.sat.conflicts', UINT, 400, 'period for solving for sat (in number of conflicts)'),
Expand Down
10 changes: 6 additions & 4 deletions src/sat/sat_prob.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace sat {
clause_vector m_clause_db;
svector<clause_info> m_clauses;
bool_vector m_values, m_best_values;
unsigned m_best_min_unsat{ 0 };
unsigned m_best_min_unsat = 0;
vector<unsigned_vector> m_use_list;
unsigned_vector m_flat_use_list;
unsigned_vector m_use_list_index;
Expand All @@ -67,9 +67,9 @@ namespace sat {
indexed_uint_set m_unsat;
random_gen m_rand;
unsigned_vector m_breaks;
uint64_t m_flips{ 0 };
uint64_t m_next_restart{ 0 };
unsigned m_restart_count{ 0 };
uint64_t m_flips = 0;
uint64_t m_next_restart = 0;
unsigned m_restart_count = 0;
stopwatch m_stopwatch;
model m_model;

Expand Down Expand Up @@ -139,6 +139,8 @@ namespace sat {
void add(solver const& s) override;

model const& get_model() const override { return m_model; }

double get_priority(bool_var v) const { return 0; }

std::ostream& display(std::ostream& out) const;

Expand Down
80 changes: 66 additions & 14 deletions src/sat/sat_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,26 @@ Revision History:

namespace sat {

/**
* Special cases of kissat style general backoff calculation.
* The version here calculates
* limit := value*log(C)^2*n*log(n)
* (effort calculation in kissat is based on ticks not clauses)
*
* respectively
* limit := conflicts + value*log(C)^2*n*log(n)
*/
void backoff::delta_effort(solver& s) {
count++;
unsigned d = value * count * log2(count + 1);
unsigned cl = log2(s.num_clauses() + 2);
limit = cl * cl * d;
}

void backoff::delta_conflicts(solver& s) {
delta_effort(s);
limit += s.m_conflicts_since_init;
}

solver::solver(params_ref const & p, reslimit& l):
solver_core(l),
Expand Down Expand Up @@ -1349,16 +1369,43 @@ namespace sat {
m_local_search->updt_params(m_params);
m_local_search->set_seed(m_rand());
scoped_rl.push_child(&(m_local_search->rlimit()));
m_local_search->rlimit().push(500000);

m_backoffs.m_local_search.delta_effort(*this);
m_local_search->rlimit().push(m_backoffs.m_local_search.limit);

m_local_search->reinit(*this);
lbool r = m_local_search->check(_lits.size(), _lits.data(), nullptr);
for (unsigned i = 0; i < m_phase.size(); ++i)
m_best_phase[i] = m_local_search->get_value(i);
if (r == l_true) {
m_conflicts_since_restart = 0;
m_conflicts_since_gc = 0;
m_next_simplify = std::max(m_next_simplify, m_conflicts_since_init + 1);
auto const& mdl = m_local_search->get_model();
if (mdl.size() == m_best_phase.size()) {
for (unsigned i = 0; i < m_best_phase.size(); ++i)
m_best_phase[i] = l_true == mdl[i];

if (r == l_true) {
m_conflicts_since_restart = 0;
m_conflicts_since_gc = 0;
m_next_simplify = std::max(m_next_simplify, m_conflicts_since_init + 1);
}
do_restart(true);
#if 0
// move higher priority variables to front
// eg., move the first 10% variables to front
svector<std::pair<double, bool_var>> priorities(mdl.size());
for (unsigned i = 0; i < mdl.size(); ++i)
priorities[i] = { m_local_search->get_priority(i), i };
std::sort(priorities.begin(), priorities.end(), [](auto& x, auto& y) { return x.first > y.first; });
for (unsigned i = priorities.size() / 10; i-- > 0; )
move_to_front(priorities[i].second);
#endif


if (l_true == r) {
for (clause const* cp : m_clauses) {
bool is_true = any_of(*cp, [&](auto lit) { return lit.sign() != m_best_phase[lit.var()]; });
if (!is_true) {
verbose_stream() << "clause is false " << *cp << "\n";
}
}
}
}
}

Expand Down Expand Up @@ -1693,7 +1740,7 @@ namespace sat {

if (!is_pos)
next_lit.neg();

TRACE("sat_decide", tout << scope_lvl() << ": next-case-split: " << next_lit << "\n";);
assign_scoped(next_lit);
return true;
Expand Down Expand Up @@ -1913,6 +1960,7 @@ namespace sat {
m_rephase_lim = 0;
m_rephase_inc = 0;
m_reorder_lim = m_config.m_reorder_base;
m_backoffs.m_local_search.value = 500;
m_reorder_inc = 0;
m_conflicts_since_restart = 0;
m_force_conflict_analysis = false;
Expand All @@ -1928,6 +1976,7 @@ namespace sat {
m_next_simplify = m_config.m_simplify_delay;
m_min_d_tk = 1.0;
m_search_lvl = 0;

if (m_learned.size() <= 2*m_clauses.size())
m_conflicts_since_gc = 0;
m_restart_next_out = 0;
Expand Down Expand Up @@ -2030,9 +2079,7 @@ namespace sat {

if (m_par) {
m_par->from_solver(*this);
if (m_par->to_solver(*this)) {
m_activity_inc = 128;
}
m_par->to_solver(*this);
}

if (m_config.m_binspr && !inconsistent()) {
Expand Down Expand Up @@ -2944,14 +2991,19 @@ namespace sat {
case PS_SAT_CACHING:
if (m_search_state == s_sat)
for (unsigned i = 0; i < m_phase.size(); ++i)
m_phase[i] = m_best_phase[i];
m_phase[i] = m_best_phase[i];
break;
case PS_RANDOM:
for (auto& p : m_phase) p = (m_rand() % 2) == 0;
break;
case PS_LOCAL_SEARCH:
if (m_search_state == s_sat)
bounded_local_search();
if (m_search_state == s_sat) {
if (m_rand() % 2 == 0)
bounded_local_search();
for (unsigned i = 0; i < m_phase.size(); ++i)
m_phase[i] = m_best_phase[i];
}

break;
default:
UNREACHABLE();
Expand Down
15 changes: 15 additions & 0 deletions src/sat/sat_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,23 @@ namespace sat {
struct no_drat_params : public params_ref {
no_drat_params() { set_bool("drat.disable", true); }
};

struct backoff {
unsigned value = 1;
unsigned lo = 0;
unsigned hi = 0;
unsigned limit = 0;
unsigned count = 0;
void delta_effort(solver& s);
void delta_conflicts(solver& s);
};

class solver : public solver_core {
public:
struct abort_solver {};
struct backoffs {
backoff m_local_search;
};
protected:
enum search_state { s_sat, s_unsat };

Expand Down Expand Up @@ -159,6 +172,7 @@ namespace sat {
unsigned m_search_next_toggle;
unsigned m_phase_counter;
unsigned m_best_phase_size;
backoffs m_backoffs;
unsigned m_rephase_lim;
unsigned m_rephase_inc;
unsigned m_reorder_lim;
Expand Down Expand Up @@ -237,6 +251,7 @@ namespace sat {
friend class lut_finder;
friend class npn3_finder;
friend class proof_trim;
friend struct backoff;
public:
solver(params_ref const & p, reslimit& l);
~solver() override;
Expand Down
2 changes: 1 addition & 1 deletion src/sat/sat_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ namespace sat {
virtual reslimit& rlimit() = 0;
virtual model const& get_model() const = 0;
virtual void collect_statistics(statistics& st) const = 0;
virtual double get_priority(bool_var v) const { return 0; }
virtual double get_priority(bool_var v) const = 0;
virtual bool get_value(bool_var v) const { return true; }
};

Expand Down

0 comments on commit 90a7586

Please sign in to comment.