Skip to content

Commit

Permalink
feat(predictor): Add support for denylisting accounts (#131)
Browse files Browse the repository at this point in the history
There are many occasions when the training data set may include accounts which
should not be predicted (e.g., accounts used for manual reconciliation of
AR/AP). This feature allows the user to stop the predictor from learning these
accounts, thus preventing contamination of the training set, without having to
maintain a separate filtered copy of their transactions.

Co-authored-by: Jakob Schnitzer <[email protected]>
  • Loading branch information
hlieberman and yagebu authored Sep 15, 2024
1 parent 2354891 commit 1b12989
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
6 changes: 6 additions & 0 deletions smart_importer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class EntryPredictor(ImporterHook):
string_tokenizer: Tokenizer can let smart_importer support more
languages. This parameter should be an callable function with
string parameter and the returning should be a list.
denylist_accounts: Transations with any of these accounts will be
removed from the training data.
"""

# pylint: disable=too-many-instance-attributes
Expand All @@ -54,10 +56,12 @@ def __init__(
predict=True,
overwrite=False,
string_tokenizer: Callable[[str], list] | None = None,
denylist_accounts: list[str] | None = None,
):
super().__init__()
self.training_data = None
self.open_accounts: dict[str, str] = {}
self.denylist_accounts = set(denylist_accounts or [])
self.pipeline: Pipeline | None = None
self.is_fitted = False
self.lock = threading.Lock()
Expand Down Expand Up @@ -133,6 +137,8 @@ def training_data_filter(self, txn):
for pos in txn.postings:
if pos.account not in self.open_accounts:
return False
if pos.account in self.denylist_accounts:
return False
if self.account == pos.account:
found_import_account = True
return found_import_account or not self.account
Expand Down
18 changes: 17 additions & 1 deletion tests/predictors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
2017-01-13 * "Gas Quick"
Assets:US:BofA:Checking -17.45 USD
2017-01-14 * "Axe Throwing with Joe"
Assets:US:BofA:Checking -13.37 USD
"""
)

Expand All @@ -43,6 +46,7 @@
2016-01-01 open Expenses:Auto:Gas USD
2016-01-01 open Expenses:Food:Groceries USD
2016-01-01 open Expenses:Food:Restaurant USD
2016-01-01 open Expenses:Denylisted USD
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -2.50 USD
Expand Down Expand Up @@ -93,6 +97,11 @@
2016-01-12 * "Gas Quick"
Assets:US:BofA:Checking -24.09 USD
Expenses:Auto:Gas
2016-01-08 * "Axe Throwing with Joe"
Assets:US:BofA:Checking -38.36 USD
Expenses:Denylisted
"""
)

Expand All @@ -105,6 +114,7 @@
"Gimme Coffee",
"Uncle Boons",
None,
None,
]

ACCOUNT_PREDICTIONS = [
Expand All @@ -116,8 +126,11 @@
"Expenses:Food:Coffee",
"Expenses:Food:Groceries",
"Expenses:Auto:Gas",
"Expenses:Food:Groceries",
]

DENYLISTED_ACCOUNTS = ["Expenses:Denylisted"]


class BasicTestImporter(ImporterProtocol):
def extract(self, file, existing_entries=None):
Expand All @@ -133,7 +146,10 @@ def file_account(self, file):


PAYEE_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPayees()])
POSTING_IMPORTER = apply_hooks(BasicTestImporter(), [PredictPostings()])
POSTING_IMPORTER = apply_hooks(
BasicTestImporter(),
[PredictPostings(denylist_accounts=DENYLISTED_ACCOUNTS)],
)


def test_empty_training_data():
Expand Down

0 comments on commit 1b12989

Please sign in to comment.