Skip to content

Commit

Permalink
Fix: Minor Dataframe cleanup (#1700)
Browse files Browse the repository at this point in the history
* minor dataframe fixes

* fix type
  • Loading branch information
eakmanrq authored May 29, 2023
1 parent dcfe67f commit a9e1483
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 14 deletions.
4 changes: 2 additions & 2 deletions sqlglot/dataframe/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _create_cte_from_expression(
sequence_id: t.Optional[str] = None,
**kwargs,
) -> t.Tuple[exp.CTE, str]:
- name = self.spark._random_name
+ name = self._create_hash_from_expression(expression)
expression_to_cte = expression.copy()
expression_to_cte.set("with", None)
cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0]
Expand Down Expand Up @@ -263,7 +263,7 @@ def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) ->
return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]

@classmethod
- def _create_hash_from_expression(cls, expression: exp.Select):
+ def _create_hash_from_expression(cls, expression: exp.Expression) -> str:
value = expression.sql(dialect="spark").encode("utf-8")
return f"t{zlib.crc32(value)}"[:6]

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dataframe/sql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def wrapper(self: DataFrame, *args, **kwargs):
self.last_op = Operation.NO_OP
last_op = self.last_op
new_op = op if op != Operation.NO_OP else last_op
- if new_op < last_op or (last_op == new_op and new_op == Operation.SELECT):
+ if new_op < last_op or (last_op == new_op == Operation.SELECT):
self = self._convert_leaf_to_cte()
df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
Expand Down
5 changes: 4 additions & 1 deletion sqlglot/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ def find_new_name(taken: t.Collection[str], base: str) -> str:

def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
"""Returns a dictionary created from an object's attributes."""
- return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
+ return {
+ **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()},
+ **kwargs,
+ }


def split_num_words(
Expand Down
8 changes: 6 additions & 2 deletions tests/dataframe/integration/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,9 @@ def test_hint_broadcast_no_alias(self):
df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))

- # TODO: Add test to make sure with and without alias are the same once ids are deterministic
+ self.assertEqual(
+ "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
+ )

def test_broadcast_func(self):
df_joined = self.df_spark_employee.join(
Expand Down Expand Up @@ -1188,6 +1189,9 @@ def test_broadcast_func(self):
df, dfs = self.compare_spark_with_sqlglot(df_joined, dfs_joined)
self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(df))
self.assertIn("ResolvedHint (strategy=broadcast)", self.get_explain_plan(dfs))
+ self.assertEqual(
+ "'UnresolvedHint BROADCAST, ['a2]", self.get_explain_plan(dfs).split("\n")[1]
+ )

def test_repartition_by_num(self):
"""
Expand Down
14 changes: 6 additions & 8 deletions tests/dataframe/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,12 @@ def test_typed_schema_nested(self):

@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_select_only(self):
- # TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query)
- self.assertIn(
+ self.assertEqual(
"SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`",
- df.sql(pretty=False),
+ df.sql(pretty=False)[0],
)

@mock.patch("sqlglot.schema", MappingSchema())
Expand All @@ -90,14 +89,13 @@ def test_select_quoted(self):

@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_with_aggs(self):
- # TODO: Do exact matches once CTE names are deterministic
query = "SELECT cola, colb FROM table"
sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"})
df = self.spark.sql(query).groupBy(F.col("cola")).agg(F.sum("colb"))
- result = df.sql(pretty=False, optimize=False)[0]
- self.assertIn("SELECT cola, colb FROM table", result)
- self.assertIn("SUM(colb)", result)
- self.assertIn("GROUP BY cola", result)
+ self.assertEqual(
+ "WITH t38189 AS (SELECT cola, colb FROM table), t42330 AS (SELECT cola, colb FROM t38189) SELECT cola, SUM(colb) FROM t42330 GROUP BY cola",
+ df.sql(pretty=False, optimize=False)[0],
+ )

@mock.patch("sqlglot.schema", MappingSchema())
def test_sql_create(self):
Expand Down

0 comments on commit a9e1483

Please sign in to comment.