From 3b006db14238e4539964395d63d9d9d5c3843511 Mon Sep 17 00:00:00 2001 From: Yuji Yokoo Date: Thu, 12 Dec 2024 15:54:41 +0900 Subject: [PATCH 1/2] Remove explicit table and column in Django model fields remove comment about removing table and column formatting --- examples/django_examples.py | 21 ++++++++++++------- src/eqlpy/eqldjango.py | 19 ++++++++++++----- .../integration/eqldjango_integration_test.py | 15 ++++++------- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/examples/django_examples.py b/examples/django_examples.py index 4bcebe9..e71b1b4 100644 --- a/examples/django_examples.py +++ b/examples/django_examples.py @@ -40,12 +40,14 @@ class TestSettings: class Customer(models.Model): - age = EncryptedInt(table="customers", column="age", null=True) - is_citizen = EncryptedBoolean(table="customers", column="is_citizen", null=True) - start_date = EncryptedDate(table="customers", column="start_date", null=True) - weight = EncryptedFloat(table="customers", column="weight", null=True) - name = EncryptedText(table="customers", column="name", null=True) - extra_info = EncryptedJsonb(table="customers", column="extra_info", null=True) + age = EncryptedInt(null=True) + is_citizen = EncryptedBoolean(null=True) + start_date = EncryptedDate(null=True) + weight = EncryptedFloat(null=True) + name = EncryptedText(null=True) + extra_info = EncryptedJsonb(null=True) + + # non-sensitive fields (not encrypted) visit_count = IntegerField() class Meta: @@ -56,7 +58,8 @@ def __str__(self): return ( f"Customer(id={self.id}, age={self.age}, is_citizen={self.is_citizen}, " f"start_date={self.start_date}, weight={self.weight}, " - f"name='{self.name}', extra_info={self.extra_info})" + f"name='{self.name}', extra_info={self.extra_info}, " + f"visit_count={self.visit_count})" ) @@ -184,7 +187,9 @@ def query_example_json_contains(): def print_end_message(): print("That's it! Thank you for following along!") - print(f"Please look at the example code ({os.path.basename(__file__)}) itself to see how records are created and queries are run.") + print( + f"Please look at the example code ({os.path.basename(__file__)}) itself to see how records are created and queries are run." + ) step = 0 diff --git a/src/eqlpy/eqldjango.py b/src/eqlpy/eqldjango.py index 2954008..c938ad5 100644 --- a/src/eqlpy/eqldjango.py +++ b/src/eqlpy/eqldjango.py @@ -11,14 +11,15 @@ class EncryptedValue(models.JSONField): def __init__(self, *args, **kwargs): - self.table = kwargs.pop("table") - self.column = kwargs.pop("column") + self.eql_table = kwargs.pop("eql_table", None) + self.eql_column = kwargs.pop("eql_column", None) + print(f"{self.__class__.__name__} constructor called") super().__init__(*args, **kwargs) def deconstruct(self): name, path, args, kwargs = super().deconstruct() - kwargs["table"] = self.table - kwargs["column"] = self.column + kwargs["eql_table"] = self.eql_table + kwargs["eql_column"] = self.eql_column return name, path, args, kwargs def get_prep_value(self, value): @@ -26,7 +27,7 @@ def get_prep_value(self, value): dict = { "k": "pt", "p": self._to_db_format(value), - "i": {"t": self.table, "c": self.column}, + "i": {"t": self.eql_table, "c": self.eql_column}, "v": 1, "q": None, } @@ -50,6 +51,14 @@ def from_db_value(self, value, expression, connection): def db_type(self, connection): return "cs_encrypted_v1" + def contribute_to_class(self, cls, name, **kwargs): + super().contribute_to_class(cls, name, **kwargs) + # if table or column are not set, use cls and name + if (not hasattr(self, "eql_table")) or (getattr(self, "eql_table") is None): + self.eql_table = cls._meta.db_table + if (not hasattr(self, "eql_column")) or (getattr(self, "eql_column") is None): + self.eql_column = name + class EncryptedInt(EncryptedValue): def _from_db_format(self, value): diff --git a/tests/integration/eqldjango_integration_test.py b/tests/integration/eqldjango_integration_test.py index 590b870..8e98e18 100644 --- a/tests/integration/eqldjango_integration_test.py +++ b/tests/integration/eqldjango_integration_test.py @@ -320,13 +320,14 @@ def test_jsonb_in_group_by(self): class Customer(models.Model): - # investigate if we can remove table and column - age = EncryptedInt(table="customers", column="age", null=True) - is_citizen = EncryptedBoolean(table="customers", column="is_citizen", null=True) - start_date = EncryptedDate(table="customers", column="start_date", null=True) - weight = EncryptedFloat(table="customers", column="weight", null=True) - name = EncryptedText(table="customers", column="name", null=True) - extra_info = EncryptedJsonb(table="customers", column="extra_info", null=True) + age = EncryptedInt(null=True) + is_citizen = EncryptedBoolean(null=True) + start_date = EncryptedDate(null=True) + weight = EncryptedFloat(null=True) + name = EncryptedText(null=True) + extra_info = EncryptedJsonb(null=True) + + # non-sensitive fields (not encrypted) visit_count = IntegerField() class Meta: From ada53b9ae8bedf54fa5cc543764aa4bf719997cb Mon Sep 17 00:00:00 2001 From: Yuji Yokoo Date: Thu, 12 Dec 2024 16:15:22 +0900 Subject: [PATCH 2/2] Update tests and remove comment --- src/eqlpy/eqldjango.py | 1 - tests/eqlpy/eqldjango_test.py | 36 +++++++++++++++++++---------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/eqlpy/eqldjango.py b/src/eqlpy/eqldjango.py index c938ad5..0432151 100644 --- a/src/eqlpy/eqldjango.py +++ b/src/eqlpy/eqldjango.py @@ -13,7 +13,6 @@ class EncryptedValue(models.JSONField): def __init__(self, *args, **kwargs): self.eql_table = kwargs.pop("eql_table", None) self.eql_column = kwargs.pop("eql_column", None) - print(f"{self.__class__.__name__} constructor called") super().__init__(*args, **kwargs) def deconstruct(self): diff --git a/tests/eqlpy/eqldjango_test.py b/tests/eqlpy/eqldjango_test.py index be1952b..55e412a 100644 --- a/tests/eqlpy/eqldjango_test.py +++ b/tests/eqlpy/eqldjango_test.py @@ -7,57 +7,55 @@ class EqlDjangoTest(unittest.TestCase): def assert_common_parts(self, parsed): self.assertIsNone(parsed["q"]) - self.assertEqual(parsed["i"]["t"], "table") - self.assertEqual(parsed["i"]["c"], "column") self.assertEqual(parsed["v"], 1) - def test_age(self): - col_type = EncryptedInt(table="table", column="column") + def test_encrypted_int(self): + col_type = EncryptedInt() prep_value = col_type.get_prep_value(-2) self.assert_common_parts(prep_value) self.assertEqual("-2", prep_value["p"]) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual(-2, db_value) - def test_is_citizen_false(self): - col_type = EncryptedBoolean(table="table", column="column") + def test_encrypted_boolean_false(self): + col_type = EncryptedBoolean() prep_value = col_type.get_prep_value(False) self.assert_common_parts(prep_value) self.assertEqual("false", prep_value["p"]) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual(False, db_value) - def test_is_citizen_true(self): - col_type = EncryptedBoolean(table="table", column="column") + def test_encrypted_boolean_true(self): + col_type = EncryptedBoolean() prep_value = col_type.get_prep_value(True) self.assert_common_parts(prep_value) self.assertEqual("true", prep_value["p"]) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual(True, db_value) - def test_start_date(self): - col_type = EncryptedDate(table="table", column="column") + def test_encrypted_date(self): + col_type = EncryptedDate() prep_value = col_type.get_prep_value(date(2024, 11, 17)) self.assert_common_parts(prep_value) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual(date(2024, 11, 17), db_value) - def test_weight(self): - col_type = EncryptedFloat(table="table", column="column") + def test_encrypted_float(self): + col_type = EncryptedFloat() prep_value = col_type.get_prep_value(-0.01) self.assert_common_parts(prep_value) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual(-0.01, db_value) def test_encrypted_text(self): - col_type = EncryptedText(table="table", column="column") + col_type = EncryptedText() prep_value = col_type.get_prep_value("test string") self.assert_common_parts(prep_value) db_value = col_type.from_db_value(prep_value, None, None) self.assertEqual("test string", db_value) - def test_extra_info(self): - col_type = EncryptedJsonb(table="table", column="column") + def test_encrypted_jsonb(self): + col_type = EncryptedJsonb() prep_value = col_type.get_prep_value({"key": "value"}) self.assert_common_parts(prep_value) db_value = col_type.from_db_value(prep_value, None, None) @@ -74,5 +72,11 @@ def test_nones(self): ] for col_type in col_types: - prep_value = col_type(table="table", column="column").get_prep_value(None) + prep_value = col_type().get_prep_value(None) self.assertIsNone(prep_value) + + def test_table_and_column_name(self): + col_type = EncryptedInt(eql_table="some_table", eql_column="some_column") + prep_value = col_type.get_prep_value(0) + self.assertEqual("some_table", prep_value["i"]["t"]) + self.assertEqual("some_column", prep_value["i"]["c"])