diff --git a/cxx/pclean/pclean.cc b/cxx/pclean/pclean.cc
index 3b9c005..a9f40a6 100644
--- a/cxx/pclean/pclean.cc
+++ b/cxx/pclean/pclean.cc
@@ -99,8 +99,9 @@ int main(int argc, char** argv) {
   if (result.count("output") > 0) {
     std::string out_fn = result["output"].as<std::string>();
     std::cout << "Savings results to " << out_fn << "\n";
-    // TODO(thomaswc): Fix this.
-    // to_txt(out_fn, gendb.hirm, encoding);
+    // Dump gendb.hirm using placeholder "domain:i" entity names.
+    T_encoding encoding = make_dummy_encoding_from_gendb(gendb);
+    to_txt(out_fn, *(gendb.hirm), encoding);
   }
 
   std::string heldout_fn = result["heldout"].as<std::string>();
diff --git a/cxx/pclean/pclean_lib.cc b/cxx/pclean/pclean_lib.cc
index b3a40f7..34421d4 100644
--- a/cxx/pclean/pclean_lib.cc
+++ b/cxx/pclean/pclean_lib.cc
@@ -84,3 +84,25 @@ DataFrame make_pclean_samples(int num_samples, int start_row, GenDB *gendb,
   return df;
 }
 
+// Builds a placeholder encoding for `gendb`: for every (domain, crp) entry of
+// gendb.domain_crps, each index i in [0, crp.max_table()] is named
+// "<domain>:<i>" in the code -> item map.  The item -> code map is returned
+// empty.
+T_encoding make_dummy_encoding_from_gendb(const GenDB& gendb) {
+  T_encoding_f item_to_code;  // Left empty; only the reverse map is filled.
+  T_encoding_r code_to_item;
+
+  for (const auto& [domain, crp] : gendb.domain_crps) {
+    // Evaluate the loop bound once.  code_to_item[domain] is only touched
+    // inside the loop, so a domain whose max_table() is negative adds no key.
+    const auto max_table = crp.max_table();
+    for (int i = 0; i <= max_table; ++i) {
+      // TODO: Make the auto-generated string include the row number
+      // and CSV field name, for ease in debugging and visualizations.
+      code_to_item[domain][i] = domain + ":" + std::to_string(i);
+    }
+  }
+
+  // Move the maps into the result; copying would duplicate every string.
+  return std::make_pair(std::move(item_to_code), std::move(code_to_item));
+}
diff --git a/cxx/pclean/pclean_lib.hh b/cxx/pclean/pclean_lib.hh
index 63cc038..a763294 100644
--- a/cxx/pclean/pclean_lib.hh
+++ b/cxx/pclean/pclean_lib.hh
@@ -22,3 +22,8 @@ void incorporate_observations(std::mt19937* prng,
 // All existing rows added to gendb should have ids < start_row.
 DataFrame make_pclean_samples(int num_samples, int start_row, GenDB *gendb,
                               std::mt19937* prng);
+
+// Makes an encoding from a GenDB.  The i-th entity from domain "domain"
+// is given the name "domain:i".  Only the code -> item map is filled in;
+// the item -> code map is returned empty.
+T_encoding make_dummy_encoding_from_gendb(const GenDB& gendb);
diff --git a/cxx/pclean/pclean_lib_test.cc b/cxx/pclean/pclean_lib_test.cc
index 03f928e..ea8acdf 100644
--- a/cxx/pclean/pclean_lib_test.cc
+++ b/cxx/pclean/pclean_lib_test.cc
@@ -166,3 +166,88 @@ observe
   BOOST_TEST(samples.data["City"].size() == 10);
   BOOST_TEST(samples.data["State"].size() == 10);
 }
+
+// Checks that make_dummy_encoding_from_gendb() names every index up to
+// crp.max_table() of every domain "<domain>:<i>" as observations are added.
+BOOST_AUTO_TEST_CASE(test_make_dummy_encoding_from_gendb) {
+  std::mt19937 prng;
+
+  std::stringstream ss(R"""(
+class School
+  name ~ string
+  degree_dist ~ categorical(k=100)
+
+class Physician
+  school ~ School
+  degree ~ stringcat(strings="MD PT NP DO PHD")
+  specialty ~ stringcat(strings="Family Med:Internal Med:Physical Therapy", delim=":")
+  # observed_degree ~ maybe_swap(degree)
+
+class City
+  name ~ string
+  state ~ stringcat(strings="AL AK AZ AR CA CO CT DE DC FL GA HI ID IL IN IA KS KY LA ME MD MA MI MN MS MO MT NE NV NH NJ NM NY NC ND OH OK OR PA RI SC SD TN TX UT VT VA WA WV WI WY")
+
+class Practice
+  city ~ City
+
+class Record
+  physician ~ Physician
+  location ~ Practice
+
+observe
+  physician.specialty as Specialty
+  physician.school.name as School
+  physician.degree as Degree
+  location.city.name as City
+  location.city.state as State
+  from Record
+)""");
+
+  PCleanSchema pclean_schema;
+  BOOST_TEST(read_schema(ss, &pclean_schema));
+
+  GenDB gendb(&prng, pclean_schema);
+
+  // Before anything is incorporated the code -> item map is expected empty.
+  T_encoding enc = make_dummy_encoding_from_gendb(gendb);
+  BOOST_TEST(enc.second.size() == 0);
+
+  std::map obs = {
+      {"Specialty", "Internal Med"},
+      {"School", "Harvard"},
+      {"Degree", "MD"},
+      {"City", "Cambridge"},
+      {"State", "MA"}};
+
+  gendb.incorporate(&prng, {0, obs}, true);
+  T_encoding enc2 = make_dummy_encoding_from_gendb(gendb);
+
+  // Check the size before any operator[] lookup: operator[] default-inserts a
+  // missing key, which would otherwise let the size check pass vacuously.
+  BOOST_TEST(enc2.second["School"].size() == 1);
+
+  BOOST_TEST(enc2.second["School"][0] == "School:0");
+  BOOST_TEST(enc2.second["Physician"][0] == "Physician:0");
+  BOOST_TEST(enc2.second["City"][0] == "City:0");
+  BOOST_TEST(enc2.second["Practice"][0] == "Practice:0");
+
+  for (int i = 1; i < 6; ++i) {
+    gendb.incorporate(&prng, {i, obs}, true);
+  }
+
+  T_encoding enc3 = make_dummy_encoding_from_gendb(gendb);
+  BOOST_TEST(enc3.second["School"].size() == 6);
+  BOOST_TEST(enc3.second["School"][0] == "School:0");
+  BOOST_TEST(enc3.second["School"][1] == "School:1");
+  BOOST_TEST(enc3.second["School"][2] == "School:2");
+  BOOST_TEST(enc3.second["School"][3] == "School:3");
+  BOOST_TEST(enc3.second["School"][4] == "School:4");
+  BOOST_TEST(enc3.second["School"][5] == "School:5");
+
+  // Test that we got all the entities.
+  for (const auto& [domain, crp] : gendb.domain_crps) {
+    for (int i = 0; i <= crp.max_table(); ++i) {
+      BOOST_TEST(enc3.second[domain].contains(i));
+    }
+  }
+}
diff --git a/cxx/util_io.hh b/cxx/util_io.hh
index b6b9f59..a6d8779 100644
--- a/cxx/util_io.hh
+++ b/cxx/util_io.hh
@@ -8,7 +8,7 @@
 
 typedef std::map> T_encoding_f;
 typedef std::map> T_encoding_r;
-typedef std::tuple<T_encoding_f, T_encoding_r> T_encoding;
+typedef std::pair<T_encoding_f, T_encoding_r> T_encoding;
 
 // Load the schema file from path.  Exits if the schema file can't be parsed.
 T_schema load_schema(const std::string& path);