Skip to content

Commit

Permalink
wip - local search - move to plugin model
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolajBjorner committed Feb 15, 2023
1 parent a1f73d3 commit c1ecc49
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 59 deletions.
9 changes: 7 additions & 2 deletions src/sat/sat_ddfw.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include "sat/sat_clause.h"
#include "sat/sat_types.h"

namespace arith {
class sls;
}

namespace sat {
class solver;
class parallel;
Expand All @@ -44,6 +48,7 @@ namespace sat {
};

class ddfw : public i_local_search {
friend class arith::sls;
public:
struct clause_info {
clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {}
Expand Down Expand Up @@ -126,7 +131,7 @@ namespace sat {
stopwatch m_stopwatch;

parallel* m_par;
scoped_ptr< local_search_plugin> m_plugin;
local_search_plugin* m_plugin = nullptr;

void flatten_use_list();

Expand All @@ -148,7 +153,7 @@ namespace sat {

inline double reward(bool_var v) const { return m_vars[v].m_reward; }

inline double plugin_reward(bool_var v) const { return m_plugin->reward(v); }
inline double plugin_reward(bool_var v) const { return is_external(v) ? m_plugin->reward(v) : reward(v); }

void set_external(bool_var v) { m_vars[v].m_external = true; }

Expand Down
90 changes: 90 additions & 0 deletions src/sat/sat_solver/sat_smt_setup.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*++
Copyright (c) 2006 Microsoft Corporation
Module Name:
sat_smt_setup.h
Author:
Nikolaj Bjorner (nbjorner) 2023-01-17
--*/
#pragma once

#include "ast/ast.h"
#include "smt/params/smt_params.h"
#include "sat/sat_config.h"
#include "ast/simplifiers/dependent_expr_state.h"

struct static_features;

namespace sat_smt {

void setup_sat_config(smt_params const& p, sat::config& config);

class setup {
ast_manager& m;
dependent_expr_state& m_st;
smt_params& m_params;
symbol m_logic;
bool m_already_configured = false;

void setup_auto_config();
void setup_default();
//
// setup_<logic>() methods do not depend on static features of the formula. So, they are safe to use
// even in an incremental setting.
//
// setup_<logic>(static_features & st) can only be used if the logical context will perform a single
// check.
//
void setup_QF_DT();
void setup_QF_UF();
void setup_QF_UF(static_features const & st);
void setup_QF_RDL();
void setup_QF_RDL(static_features & st);
void setup_QF_IDL();
void setup_QF_IDL(static_features & st);
void setup_QF_UFIDL();
void setup_QF_UFIDL(static_features & st);
void setup_QF_LRA();
void setup_QF_LRA(static_features const & st);
void setup_QF_LIA();
void setup_QF_LIRA(static_features const& st);
void setup_QF_LIA(static_features const & st);
void setup_QF_UFLIA();
void setup_QF_UFLIA(static_features & st);
void setup_QF_UFLRA();
void setup_QF_BV();
void setup_QF_AUFBV();
void setup_QF_AX();
void setup_QF_AX(static_features const & st);
void setup_QF_AUFLIA();
void setup_QF_AUFLIA(static_features const & st);
void setup_QF_FP();
void setup_QF_FPBV();
void setup_QF_S();
void setup_LRA();
void setup_CSP();
void setup_AUFLIA(bool simple_array = true);
void setup_AUFLIA(static_features const & st);
void setup_AUFLIRA(bool simple_array = true);
void setup_UFNIA();
void setup_UFLRA();
void setup_AUFLIAp();
void setup_AUFNIRA();
void setup_QF_BVRE();
void setup_unknown();
void setup_unknown(static_features & st);

public:
setup(ast_manager& m, dependent_expr_state& st, smt_params & params);
void setk_already_configured() { m_already_configured = true; }
bool already_configured() const { return m_already_configured; }
symbol const & get_logic() const { return m_logic; }
void operator()();
};
};


127 changes: 108 additions & 19 deletions src/sat/smt/arith_sls.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Copyright (c) 2023 Microsoft Corporation
Module Name:
Expand Down Expand Up @@ -112,6 +112,8 @@ namespace arith {
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
m_best_min_unsat = std::numeric_limits<unsigned>::max();

d->set(this);
}

void sls::set_bounds_begin() {
Expand Down Expand Up @@ -209,35 +211,46 @@ namespace arith {
unsigned start = s.random();
unsigned sz = unsat().size();
for (unsigned i = sz; i-- > 0; )
if (flip(unsat().elem_at((i + start) % sz)))
if (flip_clause(unsat().elem_at((i + start) % sz)))
return true;
return false;
}

bool sls::flip(unsigned cl) {
bool sls::flip_clause(unsigned cl) {
auto const& clause = get_clause(cl);
int64_t new_value;
for (literal lit : clause) {
if (is_true(lit))
continue;
auto const* ineq = atom(lit);
if (!ineq)
continue;
SASSERT(!ineq->is_true());
for (auto const& [coeff, v] : ineq->m_args) {
if (!cm(*ineq, v, new_value))
continue;
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(2,
verbose_stream() << "v" << v << " score " << score << " "
<< num_unsat << " -> " << unsat().size() << "\n");
SASSERT(num_unsat > unsat().size());
if (flip(*ineq))
return true;
}

}
return false;
}

// flip on the first positive score
// it could be changed to flip on maximal positive score
// or flip on maximal non-negative score
// or flip on first non-negative score
bool sls::flip(ineq const& ineq) {
int64_t new_value;
for (auto const& [coeff, v] : ineq.m_args) {
if (!cm(ineq, v, new_value))
continue;
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(2,
verbose_stream() << "v" << v << " score " << score << " "
<< num_unsat << " -> " << unsat().size() << "\n");
SASSERT(num_unsat > unsat().size());
return true;
}
return false;
}
Expand All @@ -246,7 +259,7 @@ namespace arith {
unsigned start = s.random();
unsigned sz = m_bool_search->num_clauses();
for (unsigned i = sz; i-- > 0; )
if (flip((i + start) % sz))
if (flip_clause((i + start) % sz))
return true;
return false;
}
Expand Down Expand Up @@ -541,9 +554,85 @@ namespace arith {

void sls::init_literal_assignment(sat::literal lit) {
auto* ineq = m_literals.get(lit.index(), nullptr);

if (ineq && is_true(lit) != (dtt(*ineq) == 0))
m_bool_search->flip(lit.var());
}

void sls::init_search() {
on_restart();
}

void sls::finish_search() {
store_best_values();
}

void sls::flip(sat::bool_var v) {
sat::literal lit(v, m_bool_search->get_value(v));
SASSERT(!is_true(lit));
auto const* ineq = atom(lit);
if (!ineq)
IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n");
if (!ineq)
return;
IF_VERBOSE(1, verbose_stream() << "flip " << lit << "\n");
SASSERT(!ineq->is_true());
flip(*ineq);
}

double sls::reward(sat::bool_var v) {
if (m_dscore_mode)
return dscore_reward(v);
else
return dtt_reward(v);
}

double sls::dtt_reward(sat::bool_var v) {
sat::literal litv(v, m_bool_search->get_value(v));
auto const* ineq = atom(litv);
if (!ineq)
return 0;
int64_t new_value;
double result = 0;
for (auto const & [coeff, x] : ineq->m_args) {
if (!cm(*ineq, x, new_value))
continue;
for (auto const [coeff, lit] : m_vars[x].m_literals) {
auto dtt_old = dtt(*atom(lit));
auto dtt_new = dtt(*atom(lit), x, new_value);
if ((dtt_new == 0) != (dtt_old == 0))
result += m_bool_search->reward(lit.var());
}
}
return result;
}

double sls::dscore_reward(sat::bool_var x) {
m_dscore_mode = false;
sat::literal litv(x, m_bool_search->get_value(x));
auto const* ineq = atom(litv);
if (!ineq)
return 0;
SASSERT(!ineq->is_true());
int64_t new_value;
double result = 0;
for (auto const& [coeff, v] : ineq->m_args)
if (cm(*ineq, v, new_value))
result += dscore(v, new_value);
return result;
}

// switch to dscore mode
void sls::on_rescale() {
m_dscore_mode = true;
}

void sls::on_save_model() {
save_best_values();
}

void sls::on_restart() {
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
}
}

21 changes: 17 additions & 4 deletions src/sat/smt/arith_sls.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace arith {
class solver;

// local search portion for arithmetic
class sls {
class sls : public sat::local_search_plugin {
enum class ineq_kind { EQ, LE, LT, NE };
enum class var_kind { INT, REAL };
typedef unsigned var_t;
Expand Down Expand Up @@ -78,7 +78,7 @@ namespace arith {
std::ostream& display(std::ostream& out) const {
bool first = true;
for (auto const& [c, v] : m_args)
out << (first? "": " + ") << c << " * v" << v, first = false;
out << (first ? "" : " + ") << c << " * v" << v, first = false;
switch (m_op) {
case ineq_kind::LE:
return out << " <= " << m_bound << "(" << m_args_value << ")";
Expand All @@ -97,7 +97,7 @@ namespace arith {
int64_t m_value;
int64_t m_best_value;
var_kind m_kind = var_kind::INT;
vector<std::pair<int64_t, sat::literal>> m_literals;
svector<std::pair<int64_t, sat::literal>> m_literals;
};

struct clause {
Expand All @@ -116,6 +116,7 @@ namespace arith {
vector<var_info> m_vars;
vector<clause> m_clauses;
svector<std::pair<lp::tv, euf::theory_var>> m_terms;
bool m_dscore_mode = false;


indexed_uint_set& unsat() { return m_bool_search->unsat_set(); }
Expand All @@ -136,7 +137,8 @@ namespace arith {
bool flip_clauses();
bool flip_dscore();
bool flip_dscore(unsigned cl);
bool flip(unsigned cl);
bool flip_clause(unsigned cl);
bool flip(ineq const& ineq);
int64_t dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); }
int64_t dtt(int64_t args, ineq const& ineq) const;
int64_t dtt(ineq const& ineq, var_t v, int64_t new_value) const;
Expand All @@ -145,6 +147,8 @@ namespace arith {
bool cm(ineq const& ineq, var_t v, int64_t& new_value);
int cm_score(var_t v, int64_t new_value);
void update(var_t v, int64_t new_value);
double dscore_reward(sat::bool_var v);
double dtt_reward(sat::bool_var v);
void paws();
int64_t dscore(var_t v, int64_t new_value) const;
void save_best_values();
Expand All @@ -163,11 +167,20 @@ namespace arith {

public:
sls(solver& s);
~sls() override {}
lbool operator ()(bool_vector& phase);
void set_bounds_begin();
void set_bounds_end(unsigned num_literals);
void set_bounds(euf::enode* n);
void set(sat::ddfw* d);

void init_search() override;
void finish_search() override;
void flip(sat::bool_var v) override;
double reward(sat::bool_var v) override;
void on_rescale() override;
void on_save_model() override;
void on_restart() override;
};

inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) {
Expand Down
Loading

0 comments on commit c1ecc49

Please sign in to comment.