Skip to content

Commit

Permalink
make dtype comparisons more readable
Browse files Browse the repository at this point in the history
we allow __eq__ between a dtype instance and a dtype class. When comparing with
a class, const and vararg are ignored.
  • Loading branch information
finn-rudolph committed Sep 30, 2024
1 parent ed5fc1b commit 289d57f
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 37 deletions.
4 changes: 1 addition & 3 deletions src/pydiverse/transform/backend/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]):

@classmethod
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> Cast:
if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance(
cast.target_type, dtypes.Int
):
if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.Int:
return sqa.func.trunc(cls.compile_col_expr(cast.val, sqa_col)).cast(
sqa.BigInteger()
)
Expand Down
10 changes: 3 additions & 7 deletions src/pydiverse/transform/backend/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ class MsSqlImpl(SqlImpl):
@classmethod
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast:
compiled_val = cls.compile_col_expr(cast.val, sqa_col)
if isinstance(cast.val.dtype(), dtypes.String) and isinstance(
cast.target_type, dtypes.Float64
):
if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64:
return sqa.case(
(compiled_val == "inf", cls.INF),
(compiled_val == "-inf", -cls.INF),
Expand All @@ -49,9 +47,7 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast:
),
)

if isinstance(cast.val.dtype(), dtypes.Float64) and isinstance(
cast.target_type, dtypes.String
):
if cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String:
compiled = sqa.cast(cls.compile_col_expr(cast.val, sqa_col), sqa.String)
return sqa.case(
(compiled == "1.#QNAN", "nan"),
Expand Down Expand Up @@ -137,7 +133,7 @@ def convert_bool_bit(expr: ColExpr | Order, wants_bool_as_bit: bool) -> ColExpr
)

elif isinstance(expr, Col):
if not wants_bool_as_bit and isinstance(expr.dtype(), dtypes.Bool):
if not wants_bool_as_bit and expr.dtype() == dtypes.Bool:
return ColFn("__eq__", expr, LiteralCol(True))
return expr

Expand Down
6 changes: 2 additions & 4 deletions src/pydiverse/transform/backend/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
return compiled

elif isinstance(expr, LiteralCol):
if isinstance(expr.dtype(), dtypes.String):
if expr.dtype() == dtypes.String:
return pl.lit(expr.val) # polars interprets strings as column names
return expr.val

Expand All @@ -190,9 +190,7 @@ def compile_col_expr(expr: ColExpr, name_in_df: dict[UUID, str]) -> pl.Expr:
pdt_type_to_polars(expr.target_type)
)

if isinstance(expr.val.dtype(), dtypes.Float64) and isinstance(
expr.target_type, dtypes.String
):
if expr.val.dtype() == dtypes.Float64 and expr.target_type == dtypes.String:
compiled = compiled.replace("NaN", "nan")

return compiled
Expand Down
1 change: 0 additions & 1 deletion src/pydiverse/transform/backend/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def export(cls, nd: AstNode, target: Target, final_select: list[Col]) -> Any:
sql_col.name: pdt_type_to_polars(col.dtype())
for sql_col, col in zip(sel.columns.values(), final_select)
},
infer_schema_length=0,
)
df.name = nd.name
return df
Expand Down
12 changes: 3 additions & 9 deletions src/pydiverse/transform/backend/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ class SqliteImpl(SqlImpl):
def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast:
compiled_val = cls.compile_col_expr(cast.val, sqa_col)

if isinstance(cast.val.dtype(), dtypes.String) and isinstance(
cast.target_type, dtypes.Float64
):
if cast.val.dtype() == dtypes.String and cast.target_type == dtypes.Float64:
return sqa.case(
(compiled_val == "inf", cls.INF),
(compiled_val == "-inf", cls.NEG_INF),
Expand All @@ -33,14 +31,10 @@ def compile_cast(cls, cast: Cast, sqa_col: dict[str, sqa.Label]) -> sqa.Cast:
),
)

elif isinstance(cast.val.dtype(), dtypes.DateTime) and isinstance(
cast.target_type, dtypes.Date
):
elif cast.val.dtype() == dtypes.DateTime and cast.target_type == dtypes.Date:
return sqa.type_coerce(sqa.func.date(compiled_val), sqa.Date())

elif isinstance(cast.val.dtype(), dtypes.Float64) and isinstance(
cast.target_type, dtypes.String
):
elif cast.val.dtype() == dtypes.Float64 and cast.target_type == dtypes.String:
return sqa.case(
(compiled_val == cls.INF, "inf"),
(compiled_val == cls.NEG_INF, "-inf"),
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/pipe/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def clean_kwargs(**kwargs) -> dict[str, list[ColExpr]]:


def when(condition: ColExpr) -> WhenClause:
if condition.dtype() is not None and not isinstance(condition.dtype(), dtypes.Bool):
if condition.dtype() is not None and condition.dtype() != dtypes.Bool:
raise TypeError(
"argument for `when` must be of boolean type, but has type "
f"`{condition.dtype()}`"
Expand Down
2 changes: 1 addition & 1 deletion src/pydiverse/transform/pipe/verbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def filter(table: Table, *predicates: ColExpr):
new._ast = Filter(table._ast, preprocess_arg(predicates, table))

for cond in new._ast.filters:
if not isinstance(cond.dtype(), dtypes.Bool):
if cond.dtype() != dtypes.Bool:
raise TypeError(
"predicates given to `filter` must be of boolean type.\n"
f"hint: {cond} is of type {cond.dtype()} instead."
Expand Down
10 changes: 5 additions & 5 deletions src/pydiverse/transform/tree/col_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __str__(self) -> str:
)

def __hash__(self) -> int:
return hash(self.uuid)
return hash(self._uuid)


class ColName(ColExpr):
Expand Down Expand Up @@ -384,7 +384,7 @@ def dtype(self):
raise TypeError(f"invalid case expression: {e}") from e

for cond, _ in self.cases:
if cond.dtype() is not None and not isinstance(cond.dtype(), dtypes.Bool):
if cond.dtype() is not None and cond.dtype() != dtypes.Bool:
raise TypeError(
f"argument `{cond}` for `when` must be of boolean type, but has "
f"type `{cond.dtype()}`"
Expand Down Expand Up @@ -475,9 +475,9 @@ def dtype(self) -> Dtype:
self.target_type.__class__,
) not in valid_casts:
hint = ""
if self.val.dtype() == dtypes.String and (
(self.target_type == dtypes.DateTime)
or (self.target_type == dtypes.Date)
if self.val.dtype() == dtypes.String and self.target_type in (
dtypes.DateTime,
dtypes.Date,
):
hint = (
"\nhint: to convert a str to datetime, call "
Expand Down
15 changes: 9 additions & 6 deletions src/pydiverse/transform/tree/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ def __init__(self, *, const: bool = False, vararg: bool = False):
self.const = const
self.vararg = vararg

def __eq__(self, other):
if type(self) is other:
def __eq__(self, rhs):
if type(self) is rhs:
return True
if type(self) is not type(other):
if type(self) is not type(rhs):
return False
if self.const != other.const:
if self.const != rhs.const:
return False
if self.vararg != other.vararg:
if self.vararg != rhs.vararg:
return False
if self.name != other.name:
if self.name != rhs.name:
return False
return True

def __ne__(self, rhs: object) -> bool:
return not self.__eq__(rhs)

def __hash__(self):
return hash((self.name, self.const, self.vararg, type(self).__qualname__))

Expand Down

0 comments on commit 289d57f

Please sign in to comment.