Skip to content

Commit

Permalink
Merge pull request #15 from cipherstash/cip-1051-remove-repetitive-ta…
Browse files Browse the repository at this point in the history
…ble-and-column

Remove explicit table and column in Django model fields
  • Loading branch information
yujiyokoo authored Dec 12, 2024
2 parents 4fc59ff + ada53b9 commit e46ba46
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 36 deletions.
21 changes: 13 additions & 8 deletions examples/django_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,14 @@ class TestSettings:


class Customer(models.Model):
age = EncryptedInt(table="customers", column="age", null=True)
is_citizen = EncryptedBoolean(table="customers", column="is_citizen", null=True)
start_date = EncryptedDate(table="customers", column="start_date", null=True)
weight = EncryptedFloat(table="customers", column="weight", null=True)
name = EncryptedText(table="customers", column="name", null=True)
extra_info = EncryptedJsonb(table="customers", column="extra_info", null=True)
age = EncryptedInt(null=True)
is_citizen = EncryptedBoolean(null=True)
start_date = EncryptedDate(null=True)
weight = EncryptedFloat(null=True)
name = EncryptedText(null=True)
extra_info = EncryptedJsonb(null=True)

# non-sensitive fields (not encrypted)
visit_count = IntegerField()

class Meta:
Expand All @@ -56,7 +58,8 @@ def __str__(self):
return (
f"Customer(id={self.id}, age={self.age}, is_citizen={self.is_citizen}, "
f"start_date={self.start_date}, weight={self.weight}, "
f"name='{self.name}', extra_info={self.extra_info})"
f"name='{self.name}', extra_info={self.extra_info}, "
f"visit_count={self.visit_count})"
)


Expand Down Expand Up @@ -184,7 +187,9 @@ def query_example_json_contains():

def print_end_message():
print("That's it! Thank you for following along!")
print(f"Please look at the example code ({os.path.basename(__file__)}) itself to see how records are created and queries are run.")
print(
f"Please look at the example code ({os.path.basename(__file__)}) itself to see how records are created and queries are run."
)


step = 0
Expand Down
18 changes: 13 additions & 5 deletions src/eqlpy/eqldjango.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@

class EncryptedValue(models.JSONField):
def __init__(self, *args, **kwargs):
self.table = kwargs.pop("table")
self.column = kwargs.pop("column")
self.eql_table = kwargs.pop("eql_table", None)
self.eql_column = kwargs.pop("eql_column", None)
super().__init__(*args, **kwargs)

def deconstruct(self):
name, path, args, kwargs = super().deconstruct()
kwargs["table"] = self.table
kwargs["column"] = self.column
kwargs["eql_table"] = self.eql_table
kwargs["eql_column"] = self.eql_column
return name, path, args, kwargs

def get_prep_value(self, value):
if value is not None:
dict = {
"k": "pt",
"p": self._to_db_format(value),
"i": {"t": self.table, "c": self.column},
"i": {"t": self.eql_table, "c": self.eql_column},
"v": 1,
"q": None,
}
Expand All @@ -50,6 +50,14 @@ def from_db_value(self, value, expression, connection):
def db_type(self, connection):
return "cs_encrypted_v1"

def contribute_to_class(self, cls, name, **kwargs):
super().contribute_to_class(cls, name, **kwargs)
# if table or column are not set, use cls and name
if (not hasattr(self, "eql_table")) or (getattr(self, "eql_table") is None):
self.eql_table = cls._meta.db_table
if (not hasattr(self, "eql_column")) or (getattr(self, "eql_column") is None):
self.eql_column = name


class EncryptedInt(EncryptedValue):
def _from_db_format(self, value):
Expand Down
36 changes: 20 additions & 16 deletions tests/eqlpy/eqldjango_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,57 +7,55 @@
class EqlDjangoTest(unittest.TestCase):
def assert_common_parts(self, parsed):
self.assertIsNone(parsed["q"])
self.assertEqual(parsed["i"]["t"], "table")
self.assertEqual(parsed["i"]["c"], "column")
self.assertEqual(parsed["v"], 1)

def test_age(self):
col_type = EncryptedInt(table="table", column="column")
def test_encrypted_int(self):
col_type = EncryptedInt()
prep_value = col_type.get_prep_value(-2)
self.assert_common_parts(prep_value)
self.assertEqual("-2", prep_value["p"])
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual(-2, db_value)

def test_is_citizen_false(self):
col_type = EncryptedBoolean(table="table", column="column")
def test_encrypted_boolean_false(self):
col_type = EncryptedBoolean()
prep_value = col_type.get_prep_value(False)
self.assert_common_parts(prep_value)
self.assertEqual("false", prep_value["p"])
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual(False, db_value)

def test_is_citizen_true(self):
col_type = EncryptedBoolean(table="table", column="column")
def test_encrypted_boolean_true(self):
col_type = EncryptedBoolean()
prep_value = col_type.get_prep_value(True)
self.assert_common_parts(prep_value)
self.assertEqual("true", prep_value["p"])
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual(True, db_value)

def test_start_date(self):
col_type = EncryptedDate(table="table", column="column")
def test_encrypted_date(self):
col_type = EncryptedDate()
prep_value = col_type.get_prep_value(date(2024, 11, 17))
self.assert_common_parts(prep_value)
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual(date(2024, 11, 17), db_value)

def test_weight(self):
col_type = EncryptedFloat(table="table", column="column")
def test_encrypted_float(self):
col_type = EncryptedFloat()
prep_value = col_type.get_prep_value(-0.01)
self.assert_common_parts(prep_value)
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual(-0.01, db_value)

def test_encrypted_text(self):
col_type = EncryptedText(table="table", column="column")
col_type = EncryptedText()
prep_value = col_type.get_prep_value("test string")
self.assert_common_parts(prep_value)
db_value = col_type.from_db_value(prep_value, None, None)
self.assertEqual("test string", db_value)

def test_extra_info(self):
col_type = EncryptedJsonb(table="table", column="column")
def test_encrypted_jsonb(self):
col_type = EncryptedJsonb()
prep_value = col_type.get_prep_value({"key": "value"})
self.assert_common_parts(prep_value)
db_value = col_type.from_db_value(prep_value, None, None)
Expand All @@ -74,5 +72,11 @@ def test_nones(self):
]

for col_type in col_types:
prep_value = col_type(table="table", column="column").get_prep_value(None)
prep_value = col_type().get_prep_value(None)
self.assertIsNone(prep_value)

def test_table_and_column_name(self):
col_type = EncryptedInt(eql_table="some_table", eql_column="some_column")
prep_value = col_type.get_prep_value(0)
self.assertEqual("some_table", prep_value["i"]["t"])
self.assertEqual("some_column", prep_value["i"]["c"])
15 changes: 8 additions & 7 deletions tests/integration/eqldjango_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,13 +320,14 @@ def test_jsonb_in_group_by(self):


class Customer(models.Model):
# investigate if we can remove table and column
age = EncryptedInt(table="customers", column="age", null=True)
is_citizen = EncryptedBoolean(table="customers", column="is_citizen", null=True)
start_date = EncryptedDate(table="customers", column="start_date", null=True)
weight = EncryptedFloat(table="customers", column="weight", null=True)
name = EncryptedText(table="customers", column="name", null=True)
extra_info = EncryptedJsonb(table="customers", column="extra_info", null=True)
age = EncryptedInt(null=True)
is_citizen = EncryptedBoolean(null=True)
start_date = EncryptedDate(null=True)
weight = EncryptedFloat(null=True)
name = EncryptedText(null=True)
extra_info = EncryptedJsonb(null=True)

# non-sensitive fields (not encrypted)
visit_count = IntegerField()

class Meta:
Expand Down

0 comments on commit e46ba46

Please sign in to comment.