Skip to content

Commit

Permalink
Resolve merge
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasColthurst committed Oct 2, 2024
2 parents 1300c6a + c87569a commit 57ef714
Show file tree
Hide file tree
Showing 17 changed files with 213 additions and 44 deletions.
10 changes: 10 additions & 0 deletions cxx/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ cc_library(
deps = [
":base",
":dirichlet_categorical",
"//emissions:string_alignment",
],
)

Expand Down Expand Up @@ -232,6 +233,15 @@ cc_test(
],
)

# Test target for string_nat_test.cc; links the :string_nat library from this
# package and the Boost test target.
cc_test(
name = "string_nat_test",
srcs = ["string_nat_test.cc"],
deps = [
":string_nat",
"@boost//:test",
],
)

cc_test(
name = "zero_mean_normal_test",
srcs = ["zero_mean_normal_test.cc"],
Expand Down
3 changes: 3 additions & 0 deletions cxx/distributions/adapter.hh
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,8 @@ class DistributionAdapter : public Distribution<std::string> {
// Forwards to d->init_theta, passing prng through unchanged.
void init_theta(std::mt19937* prng) { d->init_theta(prng); }
// Forwards to d->transition_theta, passing prng through unchanged.
void transition_theta(std::mt19937* prng) { d->transition_theta(prng); }

// TODO(thomaswc): Define nearest methods for the DistributionAdapter
// instantiations we use.

// Deletes the object pointed to by d.
// NOTE(review): presumably this adapter is the sole owner of d; no copy/move
// handling is visible in this hunk -- confirm a copied adapter cannot
// double-delete d.
~DistributionAdapter() { delete d; }
};
6 changes: 6 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,11 @@ class Distribution {
// NonconjugateDistribution need define this.
virtual void transition_theta(std::mt19937* prng) {};

// Return the value nearest to x that is given non-zero probability by
// this distribution.
// This default implementation returns x unchanged; being virtual, it can be
// overridden by distributions that need to map x to a different value.
virtual T nearest(const T& x) const {
return x;
}

virtual ~Distribution() = default;
};
3 changes: 2 additions & 1 deletion cxx/distributions/crp.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// See LICENSE.txt

#pragma once
#include <map>
#include <random>
#include <unordered_map>
#include <unordered_set>
Expand All @@ -13,7 +14,7 @@ class CRP {
public:
double alpha = 1.; // concentration parameter
int N = 0; // number of customers
std::unordered_map<int, std::unordered_set<T_item>>
std::map<int, std::unordered_set<T_item>>
tables; // map from table id to set of customers
std::unordered_map<T_item, int> assignments; // map from customer to table id

Expand Down
11 changes: 11 additions & 0 deletions cxx/distributions/dirichlet_categorical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,14 @@ void DirichletCategorical::transition_hyperparameters(std::mt19937* prng) {
alpha = alphas[i];
}
}

// Clamps x to a valid index into counts: negative values map to 0, values at
// or beyond counts.size() map to the last index, and anything else is returned
// unchanged.
// NOTE(review): if counts is empty, counts.size() - 1 wraps around as an
// unsigned value; presumably counts is never empty here -- confirm.
int DirichletCategorical::nearest(const int& x) const {
  const bool below_range = x < 0;
  if (below_range) {
    return 0;
  }

  // x is known to be non-negative at this point, so the unsigned conversion
  // preserves its value.
  const size_t num_entries = counts.size();
  if (static_cast<size_t>(x) < num_entries) {
    return x;
  }
  return num_entries - 1;
}
2 changes: 2 additions & 0 deletions cxx/distributions/dirichlet_categorical.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ class DirichletCategorical : public Distribution<int> {
int sample(std::mt19937* prng);

void transition_hyperparameters(std::mt19937* prng);

int nearest(const int& x) const;
};
8 changes: 8 additions & 0 deletions cxx/distributions/dirichlet_categorical_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,11 @@ BOOST_AUTO_TEST_CASE(test_sample_and_log_prob) {
BOOST_TEST(abs(probs[i] - approx_p) <= 3 * stddev);
}
}

// Checks that DirichletCategorical::nearest maps out-of-range inputs to the
// ends of the valid range and leaves in-range inputs untouched.
BOOST_AUTO_TEST_CASE(test_nearest) {
// Constructed with 12; the assertions below expect valid values 0 through 11.
DirichletCategorical dc(12);

// A negative input is expected to map to the smallest value.
BOOST_TEST(dc.nearest(-5) == 0);
// An input past the end is expected to map to the largest value.
BOOST_TEST(dc.nearest(99) == 11);
// An in-range input is expected to come back unchanged.
BOOST_TEST(dc.nearest(7) == 7);
}
12 changes: 12 additions & 0 deletions cxx/distributions/string_nat.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <cctype>

#include "distributions/bigram.hh"

// A distribution over natural numbers represented as strings of digits.
Expand All @@ -12,4 +14,14 @@
class StringNat : public Bigram {
 public:
  // Constructs the underlying Bigram with _max_length and the character range
  // '0' through '9'.
  StringNat(size_t _max_length = 20): Bigram(_max_length, '0', '9') {}

  // Returns a copy of x with every character that is not a decimal digit
  // removed; the relative order of the remaining digits is preserved. The
  // result is empty when x contains no digits.
  // NOTE(review): the result is not truncated; if Bigram rejects strings longer
  // than its maximum length, callers may need to shorten it -- confirm.
  std::string nearest(const std::string& x) const {
    std::string s;
    s.reserve(x.size());
    for (const char c : x) {
      // std::isdigit has undefined behavior when passed a negative value other
      // than EOF, which a plain char can hold for non-ASCII bytes. Converting
      // to unsigned char first keeps the argument in the valid range.
      if (std::isdigit(static_cast<unsigned char>(c))) {
        s += c;
      }
    }
    return s;
  }
};
24 changes: 24 additions & 0 deletions cxx/distributions/string_nat_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Apache License, Version 2.0, refer to LICENSE.txt

#define BOOST_TEST_MODULE test StringNat

#include "distributions/string_nat.hh"

#include <boost/test/included/unit_test.hpp>

// Checks that incorporating and unincorporating digit strings on a
// default-constructed StringNat updates its N member accordingly.
BOOST_AUTO_TEST_CASE(test_simple) {
StringNat sn;

sn.incorporate("42");
sn.incorporate("0");
// Two strings have been incorporated so far.
BOOST_TEST(sn.N == 2);
sn.unincorporate("0");
// One of the two has been removed again.
BOOST_TEST(sn.N == 1);
}

// Checks that StringNat::nearest keeps all-digit input unchanged and drops
// non-digit characters from mixed input.
BOOST_AUTO_TEST_CASE(test_nearest) {
StringNat sn;

// Already all digits: returned as-is.
BOOST_TEST(sn.nearest("1234") == "1234");
// Letters are removed; the remaining digits keep their order.
BOOST_TEST(sn.nearest("a77z99") == "7799");
}
22 changes: 22 additions & 0 deletions cxx/distributions/stringcat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include <algorithm>
#include <cstdlib>
#include <cassert>
#include <limits>
#include "distributions/stringcat.hh"
#include "emissions/string_alignment.hh"

int StringCat::string_to_index(const std::string& s) const {
auto it = std::find(strings.begin(), strings.end(), s);
Expand Down Expand Up @@ -35,3 +37,23 @@ std::string StringCat::sample(std::mt19937* prng) {
void StringCat::transition_hyperparameters(std::mt19937* prng) {
dc.transition_hyperparameters(prng);
}

std::string StringCat::nearest(const std::string& x) const {
if (std::find(strings.begin(), strings.end(), x) != strings.end()) {
return x;
}

const std::string *nearest = &(strings[0]);
double lowest_distance = std::numeric_limits<double>::max();
for (const std::string& s : strings) {
std::vector<StrAlignment> alignments;
topk_alignments(1, s, x, edit_distance, &alignments);
double d = alignments[0].cost;
if (d < lowest_distance) {
lowest_distance = d;
nearest = &s;
}
}

return *nearest;
}
2 changes: 2 additions & 0 deletions cxx/distributions/stringcat.hh
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@ class StringCat : public Distribution<std::string> {
void set_alpha(double alphat);

void transition_hyperparameters(std::mt19937* prng);

std::string nearest(const std::string& x) const;
};
3 changes: 3 additions & 0 deletions cxx/distributions/stringcat_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,7 @@ BOOST_AUTO_TEST_CASE(test_simple) {
auto it = std::find(strings.begin(), strings.end(), samp);
bool found = (it != strings.end());
BOOST_TEST(found);

BOOST_TEST(sc.nearest("test") == "test");
BOOST_TEST(sc.nearest("otter") == "other");
}
47 changes: 34 additions & 13 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ double GenDB::logp_score() const {

void GenDB::incorporate(
std::mt19937* prng,
const std::pair<int, std::map<std::string, ObservationVariant>>& row) {
const std::pair<int, std::map<std::string, ObservationVariant>>& row,
bool new_rows_have_unique_entities) {
int id = row.first;

// TODO: Consider not walking the DAG when new_rows_have_unique_entities =
// True.

// Maps a query relation name to an observed value.
std::map<std::string, ObservationVariant> vals = row.second;

Expand All @@ -55,7 +59,8 @@ void GenDB::incorporate(
schema.query.fields.at(query_rel).class_path;
T_items items =
sample_entities_relation(prng, schema.query.record_class,
class_path.cbegin(), class_path.cend(), id);
class_path.cbegin(), class_path.cend(), id,
new_rows_have_unique_entities);

// Incorporate the items/value into the query relation.
incorporate_query_relation(prng, query_rel, items, val);
Expand All @@ -69,13 +74,15 @@ void GenDB::incorporate(
T_items GenDB::sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end, int class_item) {
std::vector<std::string>::const_iterator class_path_end,
int class_item, bool new_rows_have_unique_entities) {
if (class_path_end - class_path_start == 1) {
// The last item in class_path is the class from which the queried attribute
// is observed (for which there's a corresponding clean relation, observing
// the attribute from the class). We need to DFS-traverse the class's
// parents, similar to PCleanSchemaHelper::compute_domains_for.
return sample_class_ancestors(prng, class_name, class_item);
return sample_class_ancestors(prng, class_name, class_item,
new_rows_have_unique_entities);
}

// These are noisy relation domains along the path from the latent cleanly-
Expand All @@ -90,11 +97,13 @@ T_items GenDB::sample_entities_relation(
std::tuple<std::string, std::string, int> ref_key = {class_name, ref_field,
class_item};
if (!reference_values.contains(ref_key)) {
sample_and_incorporate_reference(prng, ref_key, ref_class);
sample_and_incorporate_reference(prng, ref_key, ref_class,
new_rows_have_unique_entities);
}
T_items items =
sample_entities_relation(prng, ref_class, ++class_path_start,
class_path_end, reference_values.at(ref_key));
sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values.at(ref_key), new_rows_have_unique_entities);
// The order of the items corresponds to the order of the relation's domains,
// with the class (domain) corresponding to the primary key placed last on the
// list.
Expand All @@ -105,9 +114,19 @@ T_items GenDB::sample_entities_relation(
void GenDB::sample_and_incorporate_reference(
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
const std::string& ref_class) {
const std::string& ref_class, bool new_rows_have_unique_entities) {
auto [class_name, ref_field, class_item] = ref_key;
int new_val = domain_crps[ref_class].sample(prng);
int new_val;
if (new_rows_have_unique_entities) {
auto it = domain_crps[ref_class].tables.rbegin();
if (it == domain_crps[ref_class].tables.rend()) {
new_val = 0;
} else {
new_val = it->first + 1;
}
} else {
new_val = domain_crps[ref_class].sample(prng);
}

// Generate a unique ID for the sample and incorporate it into the
// domain CRP.
Expand Down Expand Up @@ -152,7 +171,7 @@ void GenDB::incorporate_query_relation(std::mt19937* prng,
// reference_values table/entity CRPs) if necessary.
T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item) {
int class_item, bool new_rows_have_unique_entities) {
T_items items;
assert(schema.classes.contains(class_name));
PCleanClass c = schema.classes.at(class_name);
Expand All @@ -164,10 +183,12 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
std::tuple<std::string, std::string, int> ref_key = {class_name, name,
class_item};
if (!reference_values.contains(ref_key)) {
sample_and_incorporate_reference(prng, ref_key, cv->class_name);
sample_and_incorporate_reference(
prng, ref_key, cv->class_name, new_rows_have_unique_entities);
}
T_items ref_items = sample_class_ancestors(prng, cv->class_name,
reference_values.at(ref_key));
T_items ref_items = sample_class_ancestors(
prng, cv->class_name, reference_values.at(ref_key),
new_rows_have_unique_entities);
items.insert(items.end(), ref_items.begin(), ref_items.end());
}
}
Expand Down
19 changes: 14 additions & 5 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,16 @@ class GenDB {
double logp_score() const;

// Incorporates a row of observed data into the GenDB instance.
// When new_rows_have_unique_entities = True, each part of the row is assumed
// to correspond to a new entity. In particular, if two entities are added
// to the same domain in the course of adding a row, those entities will also
// be unique.
// When new_rows_have_unique_entities = False, the entity id for each row part
// is sampled from the corresponding CRP.
void incorporate(
std::mt19937* prng,
const std::pair<int, std::map<std::string, ObservationVariant>>& row);
const std::pair<int, std::map<std::string, ObservationVariant>>& row,
bool new_rows_have_unique_entities);

// Incorporates a single element of a row of observed data.
void incorporate_query_relation(std::mt19937* prng,
Expand All @@ -35,18 +42,20 @@ class GenDB {
void sample_and_incorporate_reference(
std::mt19937* prng,
const std::tuple<std::string, std::string, int>& ref_key,
const std::string& ref_class);
const std::string& ref_class, bool new_rows_have_unique_entities);

// Samples a set of entities in the domains of the relation corresponding to
// class_path.
T_items sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end, int class_item);
std::vector<std::string>::const_iterator class_path_end,
int class_item, bool new_rows_have_unique_entities);

// Sample items from a class' ancestors (recursive reference fields).
T_items sample_class_ancestors(std::mt19937* prng,
const std::string& class_name, int class_item);
T_items sample_class_ancestors(
std::mt19937* prng, const std::string& class_name, int class_item,
bool new_rows_have_unique_entities);

// Populates "items" with entities by walking the DAG of reference indices,
// starting with "ind".
Expand Down
Loading

0 comments on commit 57ef714

Please sign in to comment.