From ca3bdae5c96285176860d039adeef6ac6be00393 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Tue, 17 Nov 2020 22:55:55 -0500 Subject: [PATCH 1/2] Filter transactions to only those referencing open accounts Add test to ensure that closed accounts are not predicted and adjust existing tests so that all desired accounts are opened. --- smart_importer/predictor.py | 29 +++++++++++++++++++++++++---- tests/data/multiaccounts.beancount | 6 ++++++ tests/data/simple.beancount | 5 +++++ tests/data/single-account.beancount | 3 +++ tests/pipelines_test.py | 17 +++++++++++------ tests/predictors_test.py | 26 ++++++++++++++++++++++++++ 6 files changed, 76 insertions(+), 10 deletions(-) diff --git a/smart_importer/predictor.py b/smart_importer/predictor.py index b03b0f5..85528c0 100644 --- a/smart_importer/predictor.py +++ b/smart_importer/predictor.py @@ -9,7 +9,8 @@ from beancount.core.data import ALL_DIRECTIVES from beancount.core.data import filter_txns -from beancount.core.data import Transaction +from beancount.core.data import sorted as beancount_sorted +from beancount.core.data import Transaction, Open, Close from sklearn.pipeline import FeatureUnion from sklearn.pipeline import make_pipeline from sklearn.svm import SVC @@ -41,6 +42,7 @@ class EntryPredictor(ImporterHook): def __init__(self, predict=True, suggest=False, overwrite=False): super().__init__() self.training_data = None + self.open_accounts = {} self.pipeline = None self.is_fitted = False self.lock = threading.Lock() @@ -69,9 +71,24 @@ def __call__(self, importer, file, imported_entries, existing_entries): self.train_pipeline() return self.process_entries(imported_entries) + def load_open_accounts(self, existing_entries): + """Return map of accounts which have been opened but not closed.""" + account_map = {} + if not existing_entries: + return + + for entry in beancount_sorted(existing_entries): + if isinstance(entry, Open): + account_map[entry.account] = entry + elif isinstance(entry, Close): + 
account_map.pop(entry.account) + + self.open_accounts = account_map + def load_training_data(self, existing_entries): """Load training data, i.e., a list of Beancount entries.""" training_data = existing_entries or [] + self.load_open_accounts(existing_entries) training_data = list(filter_txns(training_data)) length_all = len(training_data) training_data = [ @@ -86,9 +103,13 @@ def load_training_data(self, existing_entries): def training_data_filter(self, txn): """Filter function for the training data.""" - if self.account: - return any([pos.account == self.account for pos in txn.postings]) - return True + found_import_account = False + for pos in txn.postings: + if pos.account not in self.open_accounts: + return False + if self.account == pos.account: + found_import_account = True + return found_import_account or not self.account @property def targets(self): diff --git a/tests/data/multiaccounts.beancount b/tests/data/multiaccounts.beancount index 94ee7d6..acc015e 100644 --- a/tests/data/multiaccounts.beancount +++ b/tests/data/multiaccounts.beancount @@ -3,6 +3,12 @@ Assets:US:EUR -2.50 USD # TRAINING +2016-01-01 open Assets:US:CHF USD +2016-01-01 open Assets:US:EUR USD +2016-01-01 open Assets:US:USD USD +2016-01-01 open Expenses:Food:Swiss USD +2016-01-01 open Expenses:Food:Europe USD +2016-01-01 open Expenses:Food:Usa USD 2016-01-06 * "Foo" Assets:US:CHF -2.50 USD Expenses:Food:Swiss diff --git a/tests/data/simple.beancount b/tests/data/simple.beancount index 6b343b3..3dabd9b 100644 --- a/tests/data/simple.beancount +++ b/tests/data/simple.beancount @@ -20,6 +20,11 @@ Assets:US:BofA:Checking -5.00 USD # TRAINING +2016-01-01 open Assets:US:BofA:Checking USD +2016-01-01 open Expenses:Food:Coffee USD +2016-01-01 open Expenses:Food:Groceries USD +2016-01-01 open Expenses:Food:Restaurant USD + 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -2.50 USD Expenses:Food:Groceries diff --git a/tests/data/single-account.beancount 
b/tests/data/single-account.beancount index d40ad0c..45ff3df 100644 --- a/tests/data/single-account.beancount +++ b/tests/data/single-account.beancount @@ -3,6 +3,9 @@ Assets:US:BofA:Checking -2.50 USD # TRAINING +2016-01-01 open Assets:US:BofA:Checking USD +2016-01-01 open Expenses:Food:Groceries USD + 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -2.50 USD Expenses:Food:Groceries diff --git a/tests/pipelines_test.py b/tests/pipelines_test.py index 38b7fe7..43536d9 100644 --- a/tests/pipelines_test.py +++ b/tests/pipelines_test.py @@ -6,6 +6,10 @@ TEST_DATA, _, __ = parser.parse_string( """ +2016-01-01 open Assets:US:BofA:Checking USD +2016-01-01 open Expenses:Food:Groceries USD +2016-01-01 open Expenses:Food:Coffee USD + 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -10.00 USD @@ -22,11 +26,12 @@ Expenses:Food:Coffee """ ) -TEST_TRANSACTION = TEST_DATA[0] +TEST_TRANSACTIONS = TEST_DATA[3:] +TEST_TRANSACTION = TEST_TRANSACTIONS[0] def test_get_payee(): - assert AttrGetter("payee").transform(TEST_DATA) == [ + assert AttrGetter("payee").transform(TEST_TRANSACTIONS) == [ "Farmer Fresh", "Starbucks", "Farmer Fresh", @@ -35,7 +40,7 @@ def test_get_payee(): def test_get_narration(): - assert AttrGetter("narration").transform(TEST_DATA) == [ + assert AttrGetter("narration").transform(TEST_TRANSACTIONS) == [ "Buying groceries", "Coffee", "Groceries", @@ -44,10 +49,10 @@ def test_get_narration(): def test_get_metadata(): - txn = TEST_DATA[0] + txn = TEST_TRANSACTION txn.meta["attr"] = "value" assert AttrGetter("meta.attr").transform([txn]) == ["value"] - assert AttrGetter("meta.attr", "default").transform(TEST_DATA) == [ + assert AttrGetter("meta.attr", "default").transform(TEST_TRANSACTIONS) == [ "value", "default", "default", @@ -56,4 +61,4 @@ def test_get_metadata(): def test_get_day_of_month(): - assert AttrGetter("date.day").transform(TEST_DATA) == [6, 7, 7, 8] + assert 
AttrGetter("date.day").transform(TEST_TRANSACTIONS) == [6, 7, 7, 8] diff --git a/tests/predictors_test.py b/tests/predictors_test.py index fd6563d..3acf78c 100644 --- a/tests/predictors_test.py +++ b/tests/predictors_test.py @@ -29,11 +29,21 @@ 2017-01-12 * "Uncle Boons" "" Assets:US:BofA:Checking -27.00 USD + +2017-01-13 * "Gas Quick" + Assets:US:BofA:Checking -17.45 USD """ ) TRAINING_DATA, _, __ = parser.parse_string( """ +2016-01-01 open Assets:US:BofA:Checking USD +2016-01-01 open Expenses:Food:Coffee USD +2016-01-01 open Expenses:Auto:Diesel USD +2016-01-01 open Expenses:Auto:Gas USD +2016-01-01 open Expenses:Food:Groceries USD +2016-01-01 open Expenses:Food:Restaurant USD + 2016-01-06 * "Farmer Fresh" "Buying groceries" Assets:US:BofA:Checking -2.50 USD Expenses:Food:Groceries @@ -50,6 +60,10 @@ Assets:US:BofA:Checking -3.50 USD Expenses:Food:Coffee +2016-01-07 * "Gas Quick" + Assets:US:BofA:Checking -22.79 USD + Expenses:Auto:Diesel + 2016-01-08 * "Uncle Boons" "Eating out with Joe" Assets:US:BofA:Checking -38.36 USD Expenses:Food:Restaurant @@ -62,13 +76,23 @@ Assets:US:BofA:Checking -6.19 USD Expenses:Food:Coffee +2016-01-10 * "Gas Quick" + Assets:US:BofA:Checking -21.60 USD + Expenses:Auto:Diesel + 2016-01-10 * "Uncle Boons" "Dinner with Mary" Assets:US:BofA:Checking -35.00 USD Expenses:Food:Restaurant +2016-01-11 close Expenses:Auto:Diesel + 2016-01-11 * "Farmer Fresh" "Groceries" Assets:US:BofA:Checking -30.50 USD Expenses:Food:Groceries + +2016-01-12 * "Gas Quick" + Assets:US:BofA:Checking -24.09 USD + Expenses:Auto:Gas """ ) @@ -80,6 +104,7 @@ "Farmer Fresh", "Gimme Coffee", "Uncle Boons", + None, ] ACCOUNT_PREDICTIONS = [ @@ -90,6 +115,7 @@ "Expenses:Food:Groceries", "Expenses:Food:Coffee", "Expenses:Food:Groceries", + "Expenses:Auto:Gas", ] From 8f63d1e2d002385db341a5273ad6dec4d04ae2c4 Mon Sep 17 00:00:00 2001 From: Aaron Lindsay Date: Thu, 19 Nov 2020 14:12:50 -0500 Subject: [PATCH 2/2] tox.ini: Update for beancount move to github --- tox.ini | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index edf1730..7109986 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ deps = flake8 pylint pytest - hg+https://bitbucket.org/blais/beancount#egg=beancount + git+https://github.com/beancount/beancount@v2#egg=beancount commands = black --check smart_importer tests flake8 smart_importer tests