Skip to content

Commit

Permalink
Added MockBackend.rows("col1", "col2")[(...), (...)] helper (#49)
Browse files Browse the repository at this point in the history
This PR makes testing with `MockBackend` easier.
  • Loading branch information
nfx authored Mar 12, 2024
1 parent f394f4d commit f03a5ab
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/databricks/labs/lsql/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,32 @@ def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]
rows += stub_rows
return rows

@staticmethod
def rows(*column_names: str):
"""This method is used to create rows for the mock backend."""
number_of_columns = len(column_names)
row_factory = Row.factory(list(column_names))

class MagicFactory:
"""This class is used to create rows for the mock backend."""

def __getitem__(self, tuples: list[tuple | list] | tuple[list | tuple]) -> list[Row]:
if not isinstance(tuples, (list, tuple)):
raise TypeError(f"Expected list or tuple, got {type(tuples)}")
# fix sloppy input
if tuples and not isinstance(tuples[0], (list, tuple)):
tuples = [tuples]
out = []
for record in tuples:
if not isinstance(record, (list, tuple)):
raise TypeError(f"Expected list or tuple, got {type(record)}")
if number_of_columns != len(record):
raise TypeError(f"Expected {number_of_columns} columns, got {len(record)}: {record}")
out.append(row_factory(*record))
return out

return MagicFactory()

@staticmethod
def _row_factory(klass: Dataclass) -> type:
return Row.factory([f.name for f in dataclasses.fields(klass)])
8 changes: 8 additions & 0 deletions src/databricks/labs/lsql/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def as_dict(self) -> dict[str, Any]:
"""Convert the row to a dictionary with the same conventions as Databricks SDK."""
return dict(zip(self.__columns__, self, strict=True))

def __eq__(self, other):
"""Check if the rows are equal."""
if not isinstance(other, Row):
return False
# compare rows as dictionaries, because the order
# of fields in constructor is not guaranteed
return self.as_dict() == other.as_dict()

def __contains__(self, item):
"""Check if the column is in the row."""
return item in self.__columns__
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,14 @@ def test_mock_backend_save_table():
Row(first="aaa", second=True),
Row(first="bbb", second=False),
]


def test_mock_backend_rows_dsl():
rows = MockBackend.rows("foo", "bar")[
[1, 2],
(3, 4),
]
assert rows == [
Row(foo=1, bar=2),
Row(foo=3, bar=4),
]

0 comments on commit f03a5ab

Please sign in to comment.