Skip to content

Commit

Permalink
add some tests for entries.update_postings
Browse files Browse the repository at this point in the history
  • Loading branch information
yagebu committed Sep 15, 2024
1 parent 2259a0e commit f692d96
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 7 deletions.
26 changes: 19 additions & 7 deletions smart_importer/entries.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
"""Helpers to work with Beancount entry objects."""

from __future__ import annotations

from beancount.core.data import Posting, Transaction


def update_postings(transaction, accounts):
"""Update the list of postings of a transaction to match the accounts."""
def update_postings(
transaction: Transaction, accounts: list[str]
) -> Transaction:
"""Update the list of postings of a transaction to match the accounts.
Expects the transaction to be updated to have exactly one posting,
otherwise it is returned unchanged. Adds empty postings for all the
accounts - if the account of the single existing posting is found
in the list of accounts, it is placed there at the first occurence,
otherwise it is appended at the end.
"""

if len(transaction.postings) != 1:
return transaction

posting = transaction.postings[0]

new_postings = [
Posting(account, None, None, None, None, None) for account in accounts
]
for posting in transaction.postings:
if posting.account in accounts:
new_postings[accounts.index(posting.account)] = posting
else:
new_postings.append(posting)
if posting.account in accounts:
new_postings[accounts.index(posting.account)] = posting
else:
new_postings.append(posting)

return transaction._replace(postings=new_postings)

Expand Down
50 changes: 50 additions & 0 deletions tests/entries_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Tests for the entry helpers."""

from __future__ import annotations

# pylint: disable=missing-docstring
from beancount.parser import parser

from smart_importer.entries import update_postings

TEST_DATA, _errors, _options = parser.parse_string(
"""
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -10.00 USD
2016-01-06 * "Farmer Fresh" "Buying groceries"
Assets:US:BofA:Checking -10.00 USD
Assets:US:BofA:Checking 10.00 USD
"""
)


def test_update_postings() -> None:
txn0 = TEST_DATA[0]

def _update(accounts: list[str]) -> list[tuple[str, bool]]:
"""Update, get accounts and whether this is the original posting."""
updated = update_postings(txn0, accounts)
return [(p.account, p is txn0.postings[0]) for p in updated.postings]

assert _update(["Assets:US:BofA:Checking", "Assets:Other"]) == [
("Assets:US:BofA:Checking", True),
("Assets:Other", False),
]

assert _update(
["Assets:US:BofA:Checking", "Assets:US:BofA:Checking", "Assets:Other"]
) == [
("Assets:US:BofA:Checking", True),
("Assets:US:BofA:Checking", False),
("Assets:Other", False),
]

assert _update(["Assets:Other", "Assets:Other2"]) == [
("Assets:Other", False),
("Assets:Other2", False),
("Assets:US:BofA:Checking", True),
]

txn1 = TEST_DATA[1]
assert update_postings(txn1, ["Assets:Other"]) == txn1

0 comments on commit f692d96

Please sign in to comment.