diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 8a7e2734b..4d32fd979 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -6,6 +6,7 @@ from debts import settle from sqlalchemy import orm +from sqlalchemy.sql import func from itsdangerous import ( TimedJSONWebSignatureSerializer, URLSafeSerializer, @@ -17,6 +18,9 @@ class Project(db.Model): + class ProjectQuery(BaseQuery): + def get_by_name(self, name): + return Project.query.filter(Project.name == name).one() id = db.Column(db.String(64), primary_key=True) @@ -25,6 +29,8 @@ class Project(db.Model): contact_email = db.Column(db.String(128)) members = db.relationship("Person", backref="project") + query_class = ProjectQuery + @property def _to_serialize(self): obj = { @@ -388,8 +394,11 @@ def _to_serialize(self): def pay_each(self): """Compute what each share has to pay""" if self.owers: - # FIXME: SQL might do that more efficiently - weights = sum(i.weight for i in self.owers) + weights = ( + db.session.query(func.sum(Person.weight)) + .join(billowers, Bill) + .filter(Bill.id == self.id) + ).scalar() return self.amount / weights else: return 0 diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py index 0c99bca83..a12613c16 100644 --- a/ihatemoney/tests/tests.py +++ b/ihatemoney/tests/tests.py @@ -19,6 +19,7 @@ from ihatemoney.manage import GenerateConfig, GeneratePasswordHash, DeleteProject from ihatemoney import models from ihatemoney import utils +from sqlalchemy import orm # Unset configuration file env var if previously set os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None) @@ -2140,5 +2141,68 @@ def test_demo_project_deletion(self): self.assertEqual(len(models.Project.query.all()), 0) +class ModelsTestCase(IhatemoneyTestCase): + def test_bill_pay_each(self): + + self.post_project("raclette") + + # add members + self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) + self.client.post("/raclette/members/add", data={"name": "fred"}) + self.client.post("/raclette/members/add", data={"name": "tata"}) + # Add a member with a balance=0 : + self.client.post("/raclette/members/add", data={"name": "toto"}) + + # create bills + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "fromage à raclette", + "payer": 1, + "payed_for": [1, 2, 3], + "amount": "10.0", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "red wine", + "payer": 2, + "payed_for": [1], + "amount": "20", + }, + ) + + self.client.post( + "/raclette/add", + data={ + "date": "2011-08-10", + "what": "delicatessen", + "payer": 1, + "payed_for": [1, 2], + "amount": "10", + }, + ) + + project = models.Project.query.get_by_name(name="raclette") + alexis = models.Person.query.get_by_name(name="alexis", project=project) + alexis_bills = models.Bill.query.options( + orm.subqueryload(models.Bill.owers) + ).filter(models.Bill.owers.contains(alexis)) + for bill in alexis_bills.all(): + if bill.what == "red wine": + pay_each_expected = 20 / 2 + self.assertEqual(bill.pay_each(), pay_each_expected) + if bill.what == "fromage à raclette": + pay_each_expected = 10 / 4 + self.assertEqual(bill.pay_each(), pay_each_expected) + if bill.what == "delicatessen": + pay_each_expected = 10 / 3 + self.assertEqual(bill.pay_each(), pay_each_expected) + + if __name__ == "__main__": unittest.main()