Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions src/comp_analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ THE SOFTWARE.
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include "chibihash64.h"

using namespace GanakInt;

Expand All @@ -45,6 +49,7 @@ void CompAnalyzer::initialize(
const LiteralIndexedVector<LitWatchList> & watches, // binary clauses
ClauseAllocator const* alloc, const vector<ClauseOfs>& _long_irred_cls) // longer-than-2-long clauses
{
watches_ = &watches;
max_var = watches.end_lit().var() - 1;
comp_vars.reserve(max_var + 1);
var_freq_scores.resize(max_var + 1, 0);
Expand Down Expand Up @@ -408,6 +413,256 @@ end_sat:;
<< comp_vars.size() << " long");
}

// Builds a structure-based cache key for a (small) component using a single
// round of Weisfeiler-Lehman (WL) colour refinement.
//
// Pipeline:
//   1. Collect the component's long-clause ids and map every component variable
//      to its position (0..nVars-1) inside comp.
//   2. Per position, count long-clause occurrences and in-component binary
//      partners, and gather the signed literals of every component clause.
//   3. Initial colour of a position = (long degree, binary degree, is_indep).
//   4. One WL round: mix the position's own colour with the sorted multiset of
//      the initial colours of all positions it shares a long clause with.
//   5. Order positions by (WL colour, initial colour, variable id); the rank of
//      a position in that order is its canonical index.
//   6. Rewrite every long/ternary clause and every binary clause whose irred()
//      flag is set as a sorted list of canonical literal indices, sort the
//      clause list, and hash (nVars, clause list, is_indep profile).
//
// Invariance caveat: positions with identical colours are ordered by their
// original variable id, so the canonical form is independent of the variable
// numbering only as far as the colours tell the variables apart.
//
// comp       component to canonicalise.
// hash_seed  seed forwarded to chibihash64.
// threshold  largest nVars that is canonicalised; 0 switches the feature off.
//
// Returns a CanonInfo whose `valid` flag is set only when canonicalisation ran;
// a default-constructed CanonInfo is returned when threshold == 0, the
// component is empty or larger than threshold, or the counter is weighted.
CanonInfo CompAnalyzer::compute_canon_info(const Comp& comp,
                                           uint64_t hash_seed,
                                           uint32_t threshold) const {
  CanonInfo info;
  const uint32_t num_vars = comp.nVars();

  // Cheap size gate first; weighted() is only consulted for eligible sizes.
  const bool size_ok = threshold != 0 && num_vars != 0 && num_vars <= threshold;
  if (!size_ok) return info;
  // Structural isomorphism only implies equal counts when variables carry no
  // individual weights: with per-variable weights, two components with the
  // same clause structure can have different weighted counts. weighted() is
  // false only for FGenMpz (mode 0).
  if (counter->weighted()) return info;

  const auto vbeg = comp.vars_begin();

  // ---- Membership lookups ----

  // Ids of the long clauses that belong to this component.
  std::unordered_set<uint32_t> cls_in_comp;
  cls_in_comp.reserve(comp.num_long_cls() * 2);
  auto cl_it = comp.cls_begin();
  while (*cl_it != sentinel) {
    cls_in_comp.insert(*cl_it);
    ++cl_it;
  }

  // Variable id -> position in comp (0..num_vars-1).
  std::unordered_map<uint32_t, uint32_t> pos_of_var;
  pos_of_var.reserve(num_vars * 2);
  for (uint32_t pos = 0; pos < num_vars; ++pos) pos_of_var[vbeg[pos]] = pos;

  // ---- Degrees and clause contents ----

  // long_deg[p]: component long clauses containing the variable at position p.
  // bin_deg[p] : binary-list entries of that variable that are component variables.
  vector<uint32_t> long_deg(num_vars, 0);
  vector<uint32_t> bin_deg(num_vars, 0);

  // Clause id -> positions of the component variables occurring in it.
  std::unordered_map<uint32_t, vector<uint32_t>> members_of_cl;
  members_of_cl.reserve(comp.num_long_cls() * 2);

  // Clause id -> signed literals of the clause.
  std::unordered_map<uint32_t, vector<Lit>> lits_of_cl;
  lits_of_cl.reserve(comp.num_long_cls() * 2);

  // Appends lit to dst unless an equal literal is already stored there.
  const auto push_if_absent = [](vector<Lit>& dst, const Lit lit) {
    if (std::find(dst.begin(), dst.end(), lit) == dst.end()) dst.push_back(lit);
  };

  for (uint32_t pos = 0; pos < num_vars; ++pos) {
    const uint32_t var = vbeg[pos];

    const ClData* const occ_first = holder.begin_long(var);
    const ClData* const occ_last = occ_first + holder.orig_size_long(var);
    for (const ClData* occ = occ_first; occ != occ_last; ++occ) {
      if (cls_in_comp.find(occ->id) == cls_in_comp.end()) continue;
      members_of_cl[occ->id].push_back(pos);
      ++long_deg[pos];

      auto& lits = lits_of_cl[occ->id];
      if (occ->id >= max_tri_clid) {
        // Long clause: copy the sentinel-terminated literal run once, on the
        // first visit of this clause.
        if (!lits.empty()) continue;
        for (const Lit* p = long_clauses_data.data() + occ->off; *p != SENTINEL_LIT; ++p)
          lits.push_back(*p);
        continue;
      }
      // Ternary clause: each visit contributes the two literals stored in the
      // record; the union over all visits is the clause's literal set.
      push_if_absent(lits, occ->get_lit1());
      push_if_absent(lits, occ->get_lit2());
    }

    // Binary degree (polarity-blind; feeds the initial colour only).
    const uint32_t* const bin_first = holder.begin_bin(var);
    const uint32_t num_bins = holder.orig_size_bin(var);
    bin_deg[pos] = static_cast<uint32_t>(std::count_if(
        bin_first, bin_first + num_bins,
        [&pos_of_var](uint32_t partner) { return pos_of_var.count(partner) != 0; }));
  }

  // ---- Initial colours ----
  using VarColor = std::tuple<uint32_t, uint32_t, uint32_t>;
  vector<VarColor> init_color;
  init_color.reserve(num_vars);
  for (uint32_t pos = 0; pos < num_vars; ++pos) {
    const bool in_indep = vbeg[pos] < indep_support_end;
    init_color.emplace_back(long_deg[pos], bin_deg[pos], static_cast<uint32_t>(in_indep));
  }
  VERBOSE_DEBUG_DO(
    cout << "WL canon: nVars=" << num_vars
         << " nLongCls=" << lits_of_cl.size()
         << " initial colors (var longdeg bindeg isindep):";
    for (uint32_t pos = 0; pos < num_vars; ++pos)
      cout << " [" << vbeg[pos] << " "
           << long_deg[pos] << " " << bin_deg[pos] << " "
           << (vbeg[pos] < indep_support_end ? 1 : 0) << "]";
    cout << endl;
  );

  // ---- Long-clause neighbourhood ----
  // cl_neighbors[p] lists, with multiplicity, the positions sharing a long
  // clause with position p.
  vector<vector<uint32_t>> cl_neighbors(num_vars);
  for (const auto& entry : members_of_cl) {
    const vector<uint32_t>& members = entry.second;
    for (const uint32_t u : members) {
      for (const uint32_t w : members) {
        if (u == w) continue;
        cl_neighbors[u].push_back(w);
      }
    }
  }

  // ---- One WL refinement round ----
  // wl_color[p] mixes init_color[p] with the sorted neighbour colours.
  vector<uint64_t> wl_color(num_vars);
  for (uint32_t pos = 0; pos < num_vars; ++pos) {
    vector<VarColor> nb_colors;
    nb_colors.reserve(cl_neighbors[pos].size());
    for (const uint32_t nb : cl_neighbors[pos]) nb_colors.push_back(init_color[nb]);
    sort(nb_colors.begin(), nb_colors.end());

    const auto& [own_ld, own_bd, own_indp] = init_color[pos];
    uint64_t mix = (static_cast<uint64_t>(own_ld) * 2654435761ULL)
                 ^ (static_cast<uint64_t>(own_bd) * 40503ULL)
                 ^ (static_cast<uint64_t>(own_indp) * 2246822519ULL);
    for (const auto& [ld, bd, indp] : nb_colors) {
      mix ^= (mix >> 16) * 0x45d9f3bULL;
      mix += (static_cast<uint64_t>(ld) * 2654435761ULL)
           ^ (static_cast<uint64_t>(bd) * 40503ULL)
           ^ (static_cast<uint64_t>(indp));
    }
    wl_color[pos] = mix;
  }

  VERBOSE_DEBUG_DO(
    cout << "WL canon: wl1 colors (var wl1hash):";
    for (uint32_t pos = 0; pos < num_vars; ++pos)
      cout << " [" << vbeg[pos] << " 0x" << std::hex << wl_color[pos] << std::dec << "]";
    cout << endl;
  );

  // ---- Canonical order: (WL colour, initial colour, variable id) ----
  vector<uint32_t> perm(num_vars);
  iota(perm.begin(), perm.end(), 0);
  sort(perm.begin(), perm.end(), [&](uint32_t lhs, uint32_t rhs) {
    if (wl_color[lhs] < wl_color[rhs]) return true;
    if (wl_color[rhs] < wl_color[lhs]) return false;
    if (init_color[lhs] < init_color[rhs]) return true;
    if (init_color[rhs] < init_color[lhs]) return false;
    return vbeg[lhs] < vbeg[rhs];  // deterministic tiebreak on variable id
  });

  // canon_vars[c]     : original variable id at canonical index c.
  // canon_is_indep[c] : 1 if that variable lies below indep_support_end, else 0;
  //                     part of the hashed data so that components differing
  //                     only in their indep assignment are told apart.
  // orig_to_canon[p]  : canonical index of comp position p.
  info.canon_vars.resize(num_vars);
  info.canon_is_indep.resize(num_vars);
  vector<uint32_t> orig_to_canon(num_vars);
  for (uint32_t c = 0; c < num_vars; ++c) {
    const uint32_t p = perm[c];
    info.canon_vars[c] = vbeg[p];
    info.canon_is_indep[c] = static_cast<uint32_t>(vbeg[p] < indep_support_end);
    orig_to_canon[p] = c;
  }

  // Canonical literal index = 2 * canonical position + sign bit.
  const auto canon_lit = [&orig_to_canon](uint32_t pos, auto sign) -> uint32_t {
    return 2 * orig_to_canon[pos] + static_cast<uint32_t>(sign);
  };

  // ---- Canonical clauses: long and ternary ----
  info.sorted_canon_clauses.reserve(lits_of_cl.size() + num_vars);
  for (const auto& entry : lits_of_cl) {
    vector<uint32_t> cv;
    cv.reserve(entry.second.size());
    for (const Lit l : entry.second) {
      const auto hit = pos_of_var.find(l.var());
      // Literals over variables outside the component are dropped.
      if (hit != pos_of_var.end()) cv.push_back(canon_lit(hit->second, l.sign()));
    }
    if (cv.size() >= 2) {  // shorter results are not expected for in-comp clauses
      sort(cv.begin(), cv.end());
      info.sorted_canon_clauses.push_back(std::move(cv));
    }
  }

  // ---- Canonical clauses: binaries, read from the watch lists ----
  // Each clause is recorded once, keyed by its packed (lo << 32 | hi) literal pair.
  std::unordered_set<uint64_t> seen_bin;
  seen_bin.reserve(num_vars * 4);
  for (uint32_t pos = 0; pos < num_vars; ++pos) {
    const uint32_t var = vbeg[pos];
    for (uint32_t pol = 0; pol < 2; ++pol) {
      const bool s = (pol == 1);  // visits false first, then true
      const auto& wlist = (*watches_)[Lit(var, s)];
      for (const auto& bincl : wlist.binaries) {
        if (!bincl.irred()) continue;
        const Lit other = bincl.lit();
        const auto hit = pos_of_var.find(other.var());
        if (hit == pos_of_var.end()) continue;  // other end not in comp
        const uint32_t here = canon_lit(pos, s);
        const uint32_t there = canon_lit(hit->second, other.sign());
        const uint32_t lo = std::min(here, there);
        const uint32_t hi = std::max(here, there);
        const uint64_t key = (static_cast<uint64_t>(lo) << 32) | hi;
        if (!seen_bin.insert(key).second) continue;
        info.sorted_canon_clauses.push_back({lo, hi});
      }
    }
  }

  // ---- Lexicographic order over all canonical clauses ----
  sort(info.sorted_canon_clauses.begin(), info.sorted_canon_clauses.end());

  // ---- Structural hash ----
  // Layout: [nVars, #clauses, (size, lit0, lit1, ...) per clause, is_indep[0..nVars-1]]
  vector<uint32_t> hdata;
  hdata.reserve(2 + info.sorted_canon_clauses.size() * 4 + num_vars);
  hdata.push_back(num_vars);
  hdata.push_back(static_cast<uint32_t>(info.sorted_canon_clauses.size()));
  for (const auto& cv : info.sorted_canon_clauses) {
    hdata.push_back(static_cast<uint32_t>(cv.size()));
    hdata.insert(hdata.end(), cv.begin(), cv.end());
  }
  for (uint32_t c = 0; c < num_vars; ++c) hdata.push_back(info.canon_is_indep[c]);
  info.hash = chibihash64(hdata.data(), hdata.size() * sizeof(uint32_t), hash_seed);

  VERBOSE_DEBUG_DO(
    cout << "WL canon: final hash=0x" << std::hex << info.hash << std::dec
         << " nclauses=" << info.sorted_canon_clauses.size()
         << " canonical clauses:";
    for (const auto& cv : info.sorted_canon_clauses) {
      cout << " (";
      for (uint32_t k = 0; k < cv.size(); ++k) {
        if (k) cout << ",";
        cout << cv[k];
      }
      cout << ")";
    }
    cout << endl;
  );

  info.valid = true;
  return info;
}

// There is exactly ONE of these
CompAnalyzer::CompAnalyzer(
const LiteralIndexedVector<TriValue> & lit_values,
Expand Down
16 changes: 16 additions & 0 deletions src/comp_analyzer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ THE SOFTWARE.
#include "statistics.hpp"
#include "comp_types/comp.hpp"
#include "comp_types/comp_archetype.hpp"
#include "comp_types/canon_info.hpp"

#include <climits>
#include <cstring>
Expand Down Expand Up @@ -95,6 +96,10 @@ struct MyHolder {
auto start = data[v*hstride+offset+0];
return data.get() + start;
}
// Read-only counterpart of begin_bin(): slot offset+0 of variable v's record
// (records are hstride entries apart) holds an index into `data`; the returned
// pointer addresses that index.
const uint32_t* begin_bin(uint32_t v) const {
  return data.get() + data[v*hstride + offset + 0];
}
// Value stored in slot offset+1 of variable v's record in `data` (records are
// hstride entries apart). Presumably the current number of binary entries of v
// -- inferred from the name, confirm against the code that writes it.
uint32_t size_bin(uint32_t v) const { return data[v*hstride+offset+1];}
// Mutable reference to the same slot, so callers can update it in place.
uint32_t& size_bin(uint32_t v) { return data[v*hstride+offset+1];}
// Value stored in slot offset+2 of v's record. compute_canon_info() uses it as
// the number of uint32_t entries readable starting at begin_bin(v).
uint32_t orig_size_bin(uint32_t v) const { return data[v*hstride+offset+2];}
Expand All @@ -107,6 +112,10 @@ struct MyHolder {
auto start = data[v*hstride+offset+3];
return reinterpret_cast<ClData*>(data.get() + start);
}
// Read-only counterpart of begin_long(): slot offset+3 of variable v's record
// (records are hstride entries apart) holds an index into `data`; the storage
// at that index is reinterpreted as an array of ClData.
const ClData* begin_long(uint32_t v) const {
  return reinterpret_cast<const ClData*>(data.get() + data[v*hstride + offset + 3]);
}
// Value stored in slot offset+4 of variable v's record in `data` (records are
// hstride entries apart). Presumably the current number of long-clause entries
// of v -- inferred from the name, confirm against the code that writes it.
uint32_t size_long(uint32_t v) const { return data[v*hstride+offset+4];}
// Mutable reference to the same slot, so callers can update it in place.
uint32_t& size_long(uint32_t v) { return data[v*hstride+offset+4];}
// Value stored in slot offset+5 of v's record. compute_canon_info() uses it as
// the number of ClData entries readable starting at begin_long(v).
uint32_t orig_size_long(uint32_t v) const { return data[v*hstride+offset+5];}
Expand Down Expand Up @@ -179,6 +188,12 @@ class CompAnalyzer {
// Returns the analyzer's max_var member by value.
uint32_t get_max_var() const { return max_var; }
// Returns a mutable reference to the analyzer's archetype member.
CompArchetype& get_archetype() { return archetype; }

// Compute WL-based canonical information for comp (only if 0 < nVars <= threshold).
// Must be called immediately after make_comp_from_archetype() and before the
// next explore_comp(), because it relies on the current holder state.
// Returns a CanonInfo with valid=false if threshold == 0, nVars == 0,
// nVars > threshold, or the counter is weighted.
CanonInfo compute_canon_info(const Comp& comp, uint64_t hash_seed, uint32_t threshold) const;

private:
// the id of the last clause
// note that clause ID is the clause number,
Expand All @@ -191,6 +206,7 @@ class CompAnalyzer {
MyHolder holder;
vector<Lit> long_clauses_data;
const LiteralIndexedVector<TriValue> & values;
const LiteralIndexedVector<LitWatchList>* watches_ = nullptr;

const CounterConfiguration& conf;
const uint32_t indep_support_end;
Expand Down
5 changes: 3 additions & 2 deletions src/comp_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ class CompCache final: public CompCacheIF {
T* comp = reinterpret_cast<T*>(c);
return comp->extra_bytes();
}
void* create_new_comp(const Comp &comp, uint64_t hash_seed, const BPCSizes& bpc) override {
return new T(comp, hash_seed, bpc);
// Heap-allocates a new T from (comp, hash_seed, bpc, canon) and returns it as
// an untyped pointer; the caller receives ownership of the allocation.
// `canon` is optional and is forwarded to T's constructor unchanged.
// NOTE(review): default arguments of virtual functions are chosen from the
// static type at the call site, so this default must stay identical to the
// one on the CompCacheIF declaration.
void* create_new_comp(const Comp &comp, uint64_t hash_seed, const BPCSizes& bpc,
const CanonInfo* canon = nullptr) override {
return new T(comp, hash_seed, bpc, canon);
}

// Returns entry_base.size().
// NOTE(review): the name says "max", but the value is the current size of
// entry_base -- confirm that is the intended meaning.
[[nodiscard]] uint64_t get_max_num_entries() const override { return entry_base.size(); }
Expand Down
4 changes: 3 additions & 1 deletion src/comp_cache_if.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ THE SOFTWARE.

#include "common.hpp"
#include "comp_types/cacheable_comp.hpp"
#include "comp_types/canon_info.hpp"
#include "statistics.hpp"
#include <gmpxx.h>
#include "stack.hpp"
Expand All @@ -40,7 +41,8 @@ class CompCacheIF {
virtual CacheEntryID add_new_comp(void* comp, CacheEntryID super_comp_id) = 0;
virtual uint64_t get_extra_bytes(void* comp) const = 0;
virtual bool find_comp_and_incorporate_cnt(StackLevel &top, const uint32_t nvars, const void* comp) = 0;
virtual void* create_new_comp(const Comp &comp, uint64_t hash_seed, const BPCSizes& bpc) = 0;
virtual void* create_new_comp(const Comp &comp, uint64_t hash_seed, const BPCSizes& bpc,
const CanonInfo* canon = nullptr) = 0;
virtual void free_comp(void* comp) = 0;

virtual void make_entry_deletable(CacheEntryID id) = 0;
Expand Down
Loading
Loading