Skip to content

Commit

Permalink
Merge pull request #98 from aclindsa/filter_closed_accounts
Browse files Browse the repository at this point in the history
Filter transactions to only those referencing open accounts
  • Loading branch information
tarioch authored Nov 24, 2020
2 parents bb896bd + 8f63d1e commit 6a45ebf
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 11 deletions.
29 changes: 25 additions & 4 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from beancount.core.data import ALL_DIRECTIVES
from beancount.core.data import filter_txns
from beancount.core.data import Transaction
from beancount.core.data import sorted as beancount_sorted
from beancount.core.data import Transaction, Open, Close
from sklearn.pipeline import FeatureUnion
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
Expand Down Expand Up @@ -41,6 +42,7 @@ class EntryPredictor(ImporterHook):
def __init__(self, predict=True, suggest=False, overwrite=False):
super().__init__()
self.training_data = None
self.open_accounts = {}
self.pipeline = None
self.is_fitted = False
self.lock = threading.Lock()
Expand Down Expand Up @@ -69,9 +71,24 @@ def __call__(self, importer, file, imported_entries, existing_entries):
self.train_pipeline()
return self.process_entries(imported_entries)

def load_open_accounts(self, existing_entries):
    """Record the accounts that have been opened but not closed.

    The result is stored on ``self.open_accounts`` as a dict mapping each
    account name to its ``Open`` directive; nothing is returned.

    :param existing_entries: Iterable of Beancount entries, or a falsy
        value (``None`` / empty) when there is no existing ledger.
    """
    account_map = {}
    # Skip sorting entirely when there is nothing to scan, but still fall
    # through to the assignment below so that a stale map left over from
    # an earlier call is replaced with an empty one.
    if existing_entries:
        # Entries are walked in the order produced by beancount_sorted
        # (presumably chronological -- confirm against beancount docs) so
        # that a Close is seen after the Open it refers to.
        for entry in beancount_sorted(existing_entries):
            if isinstance(entry, Open):
                account_map[entry.account] = entry
            elif isinstance(entry, Close):
                # Tolerate a Close for an account that was never opened
                # (or was already closed) instead of raising KeyError.
                account_map.pop(entry.account, None)

    self.open_accounts = account_map

def load_training_data(self, existing_entries):
"""Load training data, i.e., a list of Beancount entries."""
training_data = existing_entries or []
self.load_open_accounts(existing_entries)
training_data = list(filter_txns(training_data))
length_all = len(training_data)
training_data = [
Expand All @@ -86,9 +103,13 @@ def load_training_data(self, existing_entries):

def training_data_filter(self, txn):
"""Filter function for the training data."""
if self.account:
return any([pos.account == self.account for pos in txn.postings])
return True
found_import_account = False
for pos in txn.postings:
if pos.account not in self.open_accounts:
return False
if self.account == pos.account:
found_import_account = True
return found_import_account or not self.account

@property
def targets(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/data/multiaccounts.beancount
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
Assets:US:EUR -2.50 USD

# TRAINING
2016-01-01 open Assets:US:CHF USD
2016-01-01 open Assets:US:EUR USD
2016-01-01 open Assets:US:USD USD
2016-01-01 open Expenses:Food:Swiss USD
2016-01-01 open Expenses:Food:Europe USD
2016-01-01 open Expenses:Food:Usa USD
2016-01-06 * "Foo"
Assets:US:CHF -2.50 USD
Expenses:Food:Swiss
Expand Down
5 changes: 5 additions & 0 deletions tests/data/simple.beancount
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
Assets:US:BofA:Checking -5.00 USD

# TRAINING
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Coffee USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Restaurant USD

2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -2.50 USD
Expenses:Food:Groceries
Expand Down
3 changes: 3 additions & 0 deletions tests/data/single-account.beancount
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Assets:US:BofA:Checking -2.50 USD

# TRAINING
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Groceries USD

2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -2.50 USD
Expenses:Food:Groceries
Expand Down
17 changes: 11 additions & 6 deletions tests/pipelines_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

TEST_DATA, _, __ = parser.parse_string(
"""
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Coffee USD
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -10.00 USD
Expand All @@ -22,11 +26,12 @@
Expenses:Food:Coffee
"""
)
TEST_TRANSACTION = TEST_DATA[0]
TEST_TRANSACTIONS = TEST_DATA[3:]
TEST_TRANSACTION = TEST_TRANSACTIONS[0]


def test_get_payee():
assert AttrGetter("payee").transform(TEST_DATA) == [
assert AttrGetter("payee").transform(TEST_TRANSACTIONS) == [
"Farmer Fresh",
"Starbucks",
"Farmer Fresh",
Expand All @@ -35,7 +40,7 @@ def test_get_payee():


def test_get_narration():
assert AttrGetter("narration").transform(TEST_DATA) == [
assert AttrGetter("narration").transform(TEST_TRANSACTIONS) == [
"Buying groceries",
"Coffee",
"Groceries",
Expand All @@ -44,10 +49,10 @@ def test_get_narration():


def test_get_metadata():
txn = TEST_DATA[0]
txn = TEST_TRANSACTION
txn.meta["attr"] = "value"
assert AttrGetter("meta.attr").transform([txn]) == ["value"]
assert AttrGetter("meta.attr", "default").transform(TEST_DATA) == [
assert AttrGetter("meta.attr", "default").transform(TEST_TRANSACTIONS) == [
"value",
"default",
"default",
Expand All @@ -56,4 +61,4 @@ def test_get_metadata():


def test_get_day_of_month():
assert AttrGetter("date.day").transform(TEST_DATA) == [6, 7, 7, 8]
assert AttrGetter("date.day").transform(TEST_TRANSACTIONS) == [6, 7, 7, 8]
26 changes: 26 additions & 0 deletions tests/predictors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,21 @@
2017-01-12 * "Uncle Boons" ""
Assets:US:BofA:Checking -27.00 USD
2017-01-13 * "Gas Quick"
Assets:US:BofA:Checking -17.45 USD
"""
)

TRAINING_DATA, _, __ = parser.parse_string(
"""
2016-01-01 open Assets:US:BofA:Checking USD
2016-01-01 open Expenses:Food:Coffee USD
2016-01-01 open Expenses:Auto:Diesel USD
2016-01-01 open Expenses:Auto:Gas USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Restaurant USD
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -2.50 USD
Expenses:Food:Groceries
Expand All @@ -50,6 +60,10 @@
Assets:US:BofA:Checking -3.50 USD
Expenses:Food:Coffee
2016-01-07 * "Gas Quick"
Assets:US:BofA:Checking -22.79 USD
Expenses:Auto:Diesel
2016-01-08 * "Uncle Boons" "Eating out with Joe"
Assets:US:BofA:Checking -38.36 USD
Expenses:Food:Restaurant
Expand All @@ -62,13 +76,23 @@
Assets:US:BofA:Checking -6.19 USD
Expenses:Food:Coffee
2016-01-10 * "Gas Quick"
Assets:US:BofA:Checking -21.60 USD
Expenses:Auto:Diesel
2016-01-10 * "Uncle Boons" "Dinner with Mary"
Assets:US:BofA:Checking -35.00 USD
Expenses:Food:Restaurant
2016-01-11 close Expenses:Auto:Diesel
2016-01-11 * "Farmer Fresh" "Groceries"
Assets:US:BofA:Checking -30.50 USD
Expenses:Food:Groceries
2016-01-12 * "Gas Quick"
Assets:US:BofA:Checking -24.09 USD
Expenses:Auto:Gas
"""
)

Expand All @@ -80,6 +104,7 @@
"Farmer Fresh",
"Gimme Coffee",
"Uncle Boons",
None,
]

ACCOUNT_PREDICTIONS = [
Expand All @@ -90,6 +115,7 @@
"Expenses:Food:Groceries",
"Expenses:Food:Coffee",
"Expenses:Food:Groceries",
"Expenses:Auto:Gas",
]


Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ deps =
flake8
pylint
pytest
hg+https://bitbucket.org/blais/beancount#egg=beancount
git+https://github.com/beancount/beancount#v2=beancount
commands =
black --check smart_importer tests
flake8 smart_importer tests
Expand Down

0 comments on commit 6a45ebf

Please sign in to comment.