From feecf41bd26f2b80bfc4e582ba6386a2a13f2e29 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Mon, 24 Oct 2022 13:14:45 -0700 Subject: [PATCH 01/19] Replace `dask_ml.wrappers.ParallelPostFit` with custom `ParallelPostFit` class (#832) * create ParallelPostFit class * _timer * update create_experiment * update comment * migrate changes from 799 * predict_proba_meta * fix gpu? * fix TypeError? * trying again * meta to output_meta * remove _timer * try import sklearn * style fix * Update wrappers.py * use ImportError --- .../physical/rel/custom/create_experiment.py | 14 +- dask_sql/physical/rel/custom/create_model.py | 19 +- dask_sql/physical/rel/custom/wrappers.py | 497 ++++++++++++++++++ docs/source/sql/ml.rst | 2 +- tests/integration/test_model.py | 76 +++ 5 files changed, 589 insertions(+), 19 deletions(-) create mode 100644 dask_sql/physical/rel/custom/wrappers.py diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py index 642456937..3d510ac18 100644 --- a/dask_sql/physical/rel/custom/create_experiment.py +++ b/dask_sql/physical/rel/custom/create_experiment.py @@ -168,12 +168,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai f"Can not import tuner {experiment_class}. Make sure you spelled it correctly and have installed all packages." ) - try: - from dask_ml.wrappers import ParallelPostFit - except ImportError: # pragma: no cover - raise ValueError( - "dask_ml must be installed to use automl and tune hyperparameters" - ) + from dask_sql.physical.rel.custom.wrappers import ParallelPostFit model = ModelClass() @@ -199,12 +194,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai f"Can not import automl model {automl_class}. Make sure you spelled it correctly and have installed all packages." ) - try: - from dask_ml.wrappers import ParallelPostFit - except ImportError: # pragma: no cover - raise ValueError( - "dask_ml must be installed to use automl and tune hyperparameters" - ) + from dask_sql.physical.rel.custom.wrappers import ParallelPostFit automl = AutoMLClass(**automl_kwargs) # should be avoided if data doesn't fit in memory diff --git a/dask_sql/physical/rel/custom/create_model.py b/dask_sql/physical/rel/custom/create_model.py index 2e6cdeb0a..179dd7971 100644 --- a/dask_sql/physical/rel/custom/create_model.py +++ b/dask_sql/physical/rel/custom/create_model.py @@ -1,6 +1,7 @@ import logging from typing import TYPE_CHECKING +import numpy as np from dask import delayed from dask_sql.datacontainer import DataContainer @@ -43,7 +44,7 @@ class CreateModelPlugin(BaseRelPlugin): unsupervised algorithms). This means, you typically want to set this parameter. * wrap_predict: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.ParallelPostFit`. + model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. Have a look into the [dask-ml docu](https://ml.dask.org/meta-estimators.html#parallel-prediction-and-transformation) to learn more about it. Defaults to false. 
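As a quick illustration of what the wrap_predict flag enables at the SQL level, here is a minimal sketch that mirrors the test added further down in this patch. The table, column, and model names are invented, and it assumes dask-sql and scikit-learn are installed:

import pandas as pd
from dask_sql import Context

c = Context()
c.create_table(
    "train_set",
    pd.DataFrame({"f1": [0.0, 1.0, 2.0, 3.0], "outcome": [False, False, True, True]}),
)

# Train a scikit-learn model; wrap_predict wraps it in ParallelPostFit
c.sql(
    """
    CREATE MODEL my_model WITH (
        model_class = 'sklearn.linear_model.LogisticRegression',
        wrap_predict = True,
        target_column = 'outcome'
    ) AS (
        SELECT * FROM train_set
    )
    """
)

# Prediction then runs partition-wise through the wrapped estimator
predictions = c.sql(
    "SELECT * FROM PREDICT(MODEL my_model, SELECT * FROM train_set)"
)
print(predictions.compute())
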
Typically you set @@ -165,10 +166,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai model = Incremental(estimator=model) if wrap_predict: - try: - from dask_ml.wrappers import ParallelPostFit - except ImportError: # pragma: no cover - raise ValueError("Wrapping requires dask-ml to be installed.") + from dask_sql.physical.rel.custom.wrappers import ParallelPostFit # When `wrap_predict` is set to True we train on single partition frames # because this is only useful for non dask distributed models @@ -183,7 +181,16 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai delayed_model = [delayed(model.fit)(x_p, y_p) for x_p, y_p in zip(X_d, y_d)] model = delayed_model[0].compute() - model = ParallelPostFit(estimator=model) + if "sklearn" in model_class: + output_meta = np.array([]) + model = ParallelPostFit( + estimator=model, + predict_meta=output_meta, + predict_proba_meta=output_meta, + transform_meta=output_meta, + ) + else: + model = ParallelPostFit(estimator=model) else: model.fit(X, y, **fit_kwargs) diff --git a/dask_sql/physical/rel/custom/wrappers.py b/dask_sql/physical/rel/custom/wrappers.py new file mode 100644 index 000000000..7ed0d0dea --- /dev/null +++ b/dask_sql/physical/rel/custom/wrappers.py @@ -0,0 +1,497 @@ +# Copyright 2017, Dask developers +# Dask-ML project - https://github.com/dask/dask-ml +"""Meta-estimators for parallelizing estimators using the scikit-learn API.""" +import logging +import warnings + +import dask.array as da +import dask.dataframe as dd +import dask.delayed +import numpy as np + +try: + import sklearn.base + import sklearn.metrics +except ImportError: # pragma: no cover + raise ImportError("sklearn must be installed") + +logger = logging.getLogger(__name__) + + +class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixin): + """Meta-estimator for parallel predict and transform. + + Parameters + ---------- + estimator : Estimator + The underlying estimator that is fit. + + scoring : string or callable, optional + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + + For evaluating multiple metrics, either give a list of (unique) + strings or a dict with names as keys and callables as values. + + NOTE that when using custom scorers, each scorer should return a + single value. Metric functions returning a list/array of values + can be wrapped into multiple scorers that return one value each. + + See :ref:`multimetric_grid_search` for an example. + + .. warning:: + + If None, the estimator's default scorer (if available) is used. + Most scikit-learn estimators will convert large Dask arrays to + a single NumPy array, which may exhaust the memory of your worker. + You probably want to always specify `scoring`. + + predict_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + + predict_proba_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict_proba`` call. 
+ This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + + transform_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``transform`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + + """ + + class_name = "ParallelPostFit" + + def __init__( + self, + estimator=None, + scoring=None, + predict_meta=None, + predict_proba_meta=None, + transform_meta=None, + ): + self.estimator = estimator + self.scoring = scoring + self.predict_meta = predict_meta + self.predict_proba_meta = predict_proba_meta + self.transform_meta = transform_meta + + def _check_array(self, X): + """Validate an array for post-fit tasks. + + Parameters + ---------- + X : Union[Array, DataFrame] + + Returns + ------- + same type as 'X' + + Notes + ----- + The following checks are applied. + + - Ensure that the array is blocked only along the samples. + """ + if isinstance(X, da.Array): + if X.ndim == 2 and X.numblocks[1] > 1: + logger.debug("auto-rechunking 'X'") + if not np.isnan(X.chunks[0]).any(): + X = X.rechunk({0: "auto", 1: -1}) + else: + X = X.rechunk({1: -1}) + return X + + @property + def _postfit_estimator(self): + # The estimator instance to use for postfit tasks like score + return self.estimator + + def fit(self, X, y=None, **kwargs): + """Fit the underlying estimator. + + Parameters + ---------- + X, y : array-like + **kwargs + Additional fit-kwargs for the underlying estimator. + + Returns + ------- + self : object + """ + logger.info("Starting fit") + result = self.estimator.fit(X, y, **kwargs) + + # Copy over learned attributes + copy_learned_attributes(result, self) + copy_learned_attributes(result, self.estimator) + return self + + def partial_fit(self, X, y=None, **kwargs): + logger.info("Starting partial_fit") + result = self.estimator.partial_fit(X, y, **kwargs) + + # Copy over learned attributes + copy_learned_attributes(result, self) + copy_learned_attributes(result, self.estimator) + return self + + def transform(self, X): + """Transform block or partition-wise for dask inputs. + + For dask inputs, a dask array or dataframe is returned. For other + inputs (NumPy array, pandas dataframe, scipy sparse matrix), the + regular return value is returned. + + If the underlying estimator does not have a ``transform`` method, then + an ``AttributeError`` is raised. 
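To make the role of these meta arguments concrete, here is a minimal standalone sketch on an invented toy dataset; it assumes scikit-learn and dask are installed and that this wrappers module is importable. Passing empty arrays mirrors what create_model.py above now does for sklearn estimators, so dask does not have to run the estimator on a dummy block to guess the output metadata:

import dask.array as da
import numpy as np
from sklearn.linear_model import LogisticRegression
from dask_sql.physical.rel.custom.wrappers import ParallelPostFit

X = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 1.0], [3.0, 0.0]])
y = np.array([0, 1, 0, 1])
estimator = LogisticRegression().fit(X, y)

# Empty np.array meta declares the output type up front instead of
# letting dask infer it from a dummy block.
wrapped = ParallelPostFit(
    estimator=estimator,
    predict_meta=np.array([]),
    predict_proba_meta=np.array([]),
    transform_meta=np.array([]),
)

dX = da.from_array(X, chunks=(2, 2))
print(wrapped.predict(dX).compute())        # one label per row
print(wrapped.predict_proba(dX).compute())  # (n_rows, n_classes) probabilities

With a dask dataframe input the same calls go through map_partitions instead of map_blocks, and when no meta is supplied the wrapper falls back to inferring it (or to dd.core.no_default).
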
+ + Parameters + ---------- + X : array-like + + Returns + ------- + transformed : array-like + """ + self._check_method("transform") + X = self._check_array(X) + output_meta = self.transform_meta + + if isinstance(X, da.Array): + if output_meta is None: + output_meta = _get_output_dask_ar_meta_for_estimator( + _transform, self._postfit_estimator, X + ) + return X.map_blocks( + _transform, estimator=self._postfit_estimator, meta=output_meta + ) + elif isinstance(X, dd._Frame): + if output_meta is None: + output_meta = _transform(X._meta_nonempty, self._postfit_estimator) + try: + return X.map_partitions( + _transform, + self._postfit_estimator, + output_meta, + meta=output_meta, + ) + except ValueError: + if output_meta is None: + # dask-dataframe relies on dd.core.no_default + # for infering meta + output_meta = dd.core.no_default + return X.map_partitions( + _transform, estimator=self._postfit_estimator, meta=output_meta + ) + else: + return _transform(X, estimator=self._postfit_estimator) + + def score(self, X, y, compute=True): + """Returns the score on the given data. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Input data, where n_samples is the number of samples and + n_features is the number of features. + + y : array-like, shape = [n_samples] or [n_samples, n_output], optional + Target relative to X for classification or regression; + None for unsupervised learning. + + Returns + ------- + score : float + return self.estimator.score(X, y) + """ + scoring = self.scoring + X = self._check_array(X) + y = self._check_array(y) + + if not scoring: + if type(self._postfit_estimator).score == sklearn.base.RegressorMixin.score: + scoring = "r2" + elif ( + type(self._postfit_estimator).score + == sklearn.base.ClassifierMixin.score + ): + scoring = "accuracy" + else: + scoring = self.scoring + + if scoring: + if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): + scorer = sklearn.metrics.get_scorer(scoring) + else: + # TODO: implement Dask-ML's get_scorer() function + # scorer = get_scorer(scoring, compute=compute) + raise NotImplementedError("get_scorer function not implemented") + return scorer(self, X, y) + else: + return self._postfit_estimator.score(X, y) + + def predict(self, X): + """Predict for X. + + For dask inputs, a dask array or dataframe is returned. For other + inputs (NumPy array, pandas dataframe, scipy sparse matrix), the + regular return value is returned. + + Parameters + ---------- + X : array-like + + Returns + ------- + y : array-like + """ + self._check_method("predict") + X = self._check_array(X) + output_meta = self.predict_meta + + if isinstance(X, da.Array): + if output_meta is None: + output_meta = _get_output_dask_ar_meta_for_estimator( + _predict, self._postfit_estimator, X + ) + + result = X.map_blocks( + _predict, + estimator=self._postfit_estimator, + drop_axis=1, + meta=output_meta, + ) + return result + + elif isinstance(X, dd._Frame): + if output_meta is None: + # dask-dataframe relies on dd.core.no_default + # for infering meta + output_meta = _predict(X._meta_nonempty, self._postfit_estimator) + try: + return X.map_partitions( + _predict, + self._postfit_estimator, + output_meta, + meta=output_meta, + ) + except ValueError: + if output_meta is None: + output_meta = dd.core.no_default + return X.map_partitions( + _predict, estimator=self._postfit_estimator, meta=output_meta + ) + else: + return _predict(X, estimator=self._postfit_estimator) + + def predict_proba(self, X): + """Probability estimates. 
+ + For dask inputs, a dask array or dataframe is returned. For other + inputs (NumPy array, pandas dataframe, scipy sparse matrix), the + regular return value is returned. + + If the underlying estimator does not have a ``predict_proba`` + method, then an ``AttributeError`` is raised. + + Parameters + ---------- + X : array or dataframe + + Returns + ------- + y : array-like + """ + X = self._check_array(X) + + self._check_method("predict_proba") + + output_meta = self.predict_proba_meta + + if isinstance(X, da.Array): + if output_meta is None: + output_meta = _get_output_dask_ar_meta_for_estimator( + _predict_proba, self._postfit_estimator, X + ) + # XXX: multiclass + return X.map_blocks( + _predict_proba, + estimator=self._postfit_estimator, + meta=output_meta, + chunks=(X.chunks[0], len(self._postfit_estimator.classes_)), + ) + elif isinstance(X, dd._Frame): + if output_meta is None: + # dask-dataframe relies on dd.core.no_default + # for infering meta + output_meta = _predict_proba(X._meta_nonempty, self._postfit_estimator) + try: + return X.map_partitions( + _predict_proba, + self._postfit_estimator, + output_meta, + meta=output_meta, + ) + except ValueError: + if output_meta is None: + output_meta = dd.core.no_default + return X.map_partitions( + _predict_proba, estimator=self._postfit_estimator, meta=output_meta + ) + else: + return _predict_proba(X, estimator=self._postfit_estimator) + + def predict_log_proba(self, X): + """Log of probability estimates. + + For dask inputs, a dask array or dataframe is returned. For other + inputs (NumPy array, pandas dataframe, scipy sparse matrix), the + regular return value is returned. + + If the underlying estimator does not have a ``predict_proba`` + method, then an ``AttributeError`` is raised. + + Parameters + ---------- + X : array or dataframe + + Returns + ------- + y : array-like + """ + self._check_method("predict_log_proba") + return da.log(self.predict_proba(X)) + + def _check_method(self, method): + """Check if self.estimator has 'method'. 
+ + Raises + ------ + AttributeError + """ + estimator = self._postfit_estimator + if not hasattr(estimator, method): + msg = "The wrapped estimator '{}' does not have a '{}' method.".format( + estimator, method + ) + raise AttributeError(msg) + return getattr(estimator, method) + + +def _predict(part, estimator, output_meta=None): + if part.shape[0] == 0 and output_meta is not None: + empty_output = handle_empty_partitions(output_meta) + if empty_output is not None: + return empty_output + return estimator.predict(part) + + +def _predict_proba(part, estimator, output_meta=None): + if part.shape[0] == 0 and output_meta is not None: + empty_output = handle_empty_partitions(output_meta) + if empty_output is not None: + return empty_output + return estimator.predict_proba(part) + + +def _transform(part, estimator, output_meta=None): + if part.shape[0] == 0 and output_meta is not None: + empty_output = handle_empty_partitions(output_meta) + if empty_output is not None: + return empty_output + return estimator.transform(part) + + +def handle_empty_partitions(output_meta): + if hasattr(output_meta, "__array_function__"): + if len(output_meta.shape) == 1: + shape = 0 + else: + shape = list(output_meta.shape) + shape[0] = 0 + ar = np.zeros( + shape=shape, + dtype=output_meta.dtype, + like=output_meta, + ) + return ar + elif "scipy.sparse" in type(output_meta).__module__: + # sparse matrices don't support + # `like` due to non implemented __array_function__ + # Refer https://github.com/scipy/scipy/issues/10362 + # Note below works for both cupy and scipy sparse matrices + if len(output_meta.shape) == 1: + shape = 0 + else: + shape = list(output_meta.shape) + shape[0] = 0 + ar = type(output_meta)(shape, dtype=output_meta.dtype) + return ar + elif hasattr(output_meta, "iloc"): + return output_meta.iloc[:0, :] + + +def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar): + """ + Returns the output metadata array + for the model function (predict, transform etc) + by running the appropriate function on dummy data + of shape (1, n_features) + + Parameters + ---------- + + model_fun: Model function + _predict, _transform etc + + estimator : Estimator + The underlying estimator that is fit. + + input_dask_ar: The input dask_array + + Returns + ------- + metadata: metadata of output dask array + + """ + # sklearn fails if input array has size size + # It requires at least 1 sample to run successfully + input_meta = input_dask_ar._meta + if hasattr(input_meta, "__array_function__"): + ar = np.zeros( + shape=(1, input_dask_ar.shape[1]), + dtype=input_dask_ar.dtype, + like=input_meta, + ) + elif "scipy.sparse" in type(input_meta).__module__: + # sparse matrices dont support + # `like` due to non implimented __array_function__ + # Refer https://github.com/scipy/scipy/issues/10362 + # Note below works for both cupy and scipy sparse matrices + ar = type(input_meta)((1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype) + else: + func_name = model_fn.__name__.strip("_") + msg = ( + f"Metadata for {func_name} is not provided, so Dask is " + f"running the {func_name} " + "function on a small dataset to guess output metadata. " + "As a result, It is possible that Dask will guess incorrectly." 
+ ) + warnings.warn(msg) + ar = np.zeros(shape=(1, input_dask_ar.shape[1]), dtype=input_dask_ar.dtype) + return model_fn(ar, estimator) + + +def copy_learned_attributes(from_estimator, to_estimator): + attrs = {k: v for k, v in vars(from_estimator).items() if k.endswith("_")} + + for k, v in attrs.items(): + setattr(to_estimator, k, v) diff --git a/docs/source/sql/ml.rst b/docs/source/sql/ml.rst index 931cdc5ee..5c3a3b9d1 100644 --- a/docs/source/sql/ml.rst +++ b/docs/source/sql/ml.rst @@ -62,7 +62,7 @@ The key-value parameters control, how and which model is trained: want to set this parameter. * ``wrap_predict``: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.ParallelPostFit`. + model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. Have a look into the `dask-ml docu on ParallelPostFit `_ to learn more about it. Defaults to false. Typically you set diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 3c1bd1a69..ad48e5b44 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -934,3 +934,79 @@ def test_experiment_automl_regressor(c, client, training_df): ), "Best model was not registered" check_trained_model(c, "my_automl_exp2") + + +# TODO - many ML tests fail on clusters without sklearn - can we avoid this? +@skip_if_external_scheduler +def test_predict_with_nullable_types(c): + df = pd.DataFrame( + { + "rough_day_of_year": [0, 1, 2, 3], + "prev_day_inches_rained": [0.0, 1.0, 2.0, 3.0], + "rained": [False, False, False, True], + } + ) + c.create_table("train_set", df) + + model_class = "'sklearn.linear_model.LogisticRegression'" + + c.sql( + f""" + CREATE OR REPLACE MODEL model WITH ( + model_class = {model_class}, + wrap_predict = True, + wrap_fit = False, + target_column = 'rained' + ) AS ( + SELECT * + FROM train_set + ) + """ + ) + + expected = c.sql( + """ + SELECT * FROM PREDICT( + MODEL model, + SELECT * FROM train_set + ) + """ + ) + + df = pd.DataFrame( + { + "rough_day_of_year": pd.Series([0, 1, 2, 3], dtype="Int32"), + "prev_day_inches_rained": pd.Series([0.0, 1.0, 2.0, 3.0], dtype="Float32"), + "rained": pd.Series([False, False, False, True]), + } + ) + c.create_table("train_set", df) + + c.sql( + f""" + CREATE OR REPLACE MODEL model WITH ( + model_class = {model_class}, + wrap_predict = True, + wrap_fit = False, + target_column = 'rained' + ) AS ( + SELECT * + FROM train_set + ) + """ + ) + + result = c.sql( + """ + SELECT * FROM PREDICT( + MODEL model, + SELECT * FROM train_set + ) + """ + ) + + assert_eq( + expected, + result, + check_dtype=False, + ) From 74ef3975154566286076ade8e393eb74dd80a78f Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Tue, 25 Oct 2022 15:11:49 -0400 Subject: [PATCH 02/19] Add `py` to testing environments to resolve pytest 7.2.0 issues (#890) * Add py to CI environments * Update TODO message --- continuous_integration/environment-3.10-dev.yaml | 3 +++ continuous_integration/environment-3.8-dev.yaml | 3 +++ continuous_integration/environment-3.9-dev.yaml | 3 +++ continuous_integration/gpuci/environment.yaml | 3 +++ 4 files changed, 12 insertions(+) diff --git a/continuous_integration/environment-3.10-dev.yaml b/continuous_integration/environment-3.10-dev.yaml index af55d33b5..e86d4e62f 100644 --- a/continuous_integration/environment-3.10-dev.yaml +++ b/continuous_integration/environment-3.10-dev.yaml @@ -17,6 +17,9 @@ dependencies: - pandas>=1.4.0 - pre-commit - 
prompt_toolkit>=3.0.8 +# TODO: remove once py is added to pytest downstream libraries +# https://github.com/pytest-dev/pytest-xdist/issues/832 +- py - psycopg2 - pyarrow>=6.0.1 - pygments>=2.7.1 diff --git a/continuous_integration/environment-3.8-dev.yaml b/continuous_integration/environment-3.8-dev.yaml index dca95257d..33b6492db 100644 --- a/continuous_integration/environment-3.8-dev.yaml +++ b/continuous_integration/environment-3.8-dev.yaml @@ -18,6 +18,9 @@ dependencies: - pre-commit - prompt_toolkit=3.0.8 - psycopg2 +# TODO: remove once py is added to pytest downstream libraries +# https://github.com/pytest-dev/pytest-xdist/issues/832 +- py - pyarrow=6.0.1 - pygments=2.7.1 - pyhive diff --git a/continuous_integration/environment-3.9-dev.yaml b/continuous_integration/environment-3.9-dev.yaml index 52ec271d3..8a2a2bcb0 100644 --- a/continuous_integration/environment-3.9-dev.yaml +++ b/continuous_integration/environment-3.9-dev.yaml @@ -18,6 +18,9 @@ dependencies: - pre-commit - prompt_toolkit>=3.0.8 - psycopg2 +# TODO: remove once py is added to pytest downstream libraries +# https://github.com/pytest-dev/pytest-xdist/issues/832 +- py - pyarrow>=6.0.1 - pygments>=2.7.1 - pyhive diff --git a/continuous_integration/gpuci/environment.yaml b/continuous_integration/gpuci/environment.yaml index 2e7817cfc..c839083e6 100644 --- a/continuous_integration/gpuci/environment.yaml +++ b/continuous_integration/gpuci/environment.yaml @@ -20,6 +20,9 @@ dependencies: - pandas>=1.4.0 - pre-commit - prompt_toolkit>=3.0.8 +# TODO: remove once py is added to pytest downstream libraries +# https://github.com/pytest-dev/pytest-xdist/issues/832 +- py - psycopg2 - pyarrow>=6.0.1 - pygments>=2.7.1 From aa01db88ae33ab23095f353400a6cb72c2cbac60 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 25 Oct 2022 14:18:21 -0600 Subject: [PATCH 03/19] Use latest DataFusion rev (#889) Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> --- dask_planner/Cargo.lock | 33 +++++++++++++++++---------------- dask_planner/Cargo.toml | 10 +++++----- dask_planner/src/expression.rs | 10 ++++++---- dask_planner/src/sql.rs | 2 +- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index b6e917abf..b3592afe0 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -58,9 +58,9 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" -version = "24.0.0" +version = "25.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d68391300d5237f6725f0f869ae7cb65d45fcf8a6d18f6ceecd328fb803bef93" +checksum = "76312eb67808c67341f4234861c4fcd2f9868f55e88fa2186ab3b357a6c5830b" dependencies = [ "ahash 0.8.0", "arrow-array", @@ -86,9 +86,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "24.0.0" +version = "25.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0bb00c5862b5eea683812083c495bef01a9a5149da46ad2f4c0e4aa8800f64d" +checksum = "69dd2c257fa76de0bcc63cabe8c81d34c46ef6fa7651e3e497922c3c9878bd67" dependencies = [ "ahash 0.8.0", "arrow-buffer", @@ -102,18 +102,19 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "24.0.0" +version = "25.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e594d0fe0026a8bc2459bdc5ac9623e5fb666724a715e0acbc96ba30c5d4cc7" +checksum = 
"af963e71bdbbf928231d521083ddc8e8068cf5c8d45d4edcfeaf7eb5cdd779a9" dependencies = [ "half", + "num", ] [[package]] name = "arrow-data" -version = "24.0.0" +version = "25.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8500df05060d86fdc53e9b5cb32e51bfeaacc040fdeced3eb99ac0d59200ff45" +checksum = "52554ffff560c366d7210c2621a3cf1dc408f9969a0c7688a3ba0a62248a945d" dependencies = [ "arrow-buffer", "arrow-schema", @@ -123,9 +124,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "24.0.0" +version = "25.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86d1fef01f25e1452c86fa6887f078de8e0aaeeb828370feab205944cfc30e27" +checksum = "1a5518f2bd7775057391f88257627cbb760ba3e1c2f2444a005ba79158624654" [[package]] name = "async-trait" @@ -350,7 +351,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "arrow", "ordered-float", @@ -360,7 +361,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "ahash 0.8.0", "arrow", @@ -372,7 +373,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "arrow", "async-trait", @@ -387,7 +388,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "ahash 0.8.0", "arrow", @@ -411,7 +412,7 @@ dependencies = [ [[package]] name = "datafusion-row" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "arrow", "datafusion-common", @@ -422,7 +423,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "13.0.0" -source = "git+https://github.com/apache/arrow-datafusion/?rev=d2d8447ed23bab3c9b9fd89abe469d584e84df6b#d2d8447ed23bab3c9b9fd89abe469d584e84df6b" +source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2" dependencies = [ "arrow", "datafusion-common", diff --git a/dask_planner/Cargo.toml 
b/dask_planner/Cargo.toml index e3075f64d..88b8921fa 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -9,12 +9,12 @@ edition = "2021" rust-version = "1.62" [dependencies] -arrow = { version = "24.0.0", features = ["prettyprint"] } +arrow = { version = "25.0.0", features = ["prettyprint"] } async-trait = "0.1.58" -datafusion-common = { git = "https://github.com/apache/arrow-datafusion/", rev = "d2d8447ed23bab3c9b9fd89abe469d584e84df6b" } -datafusion-expr = { git = "https://github.com/apache/arrow-datafusion/", rev = "d2d8447ed23bab3c9b9fd89abe469d584e84df6b" } -datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion/", rev = "d2d8447ed23bab3c9b9fd89abe469d584e84df6b" } -datafusion-sql = { git = "https://github.com/apache/arrow-datafusion/", rev = "d2d8447ed23bab3c9b9fd89abe469d584e84df6b" } +datafusion-common = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" } +datafusion-expr = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" } +datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" } +datafusion-sql = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" } env_logger = "0.9" log = "^0.4" mimalloc = { version = "*", default-features = false } diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index cc5994385..f69f120ed 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -3,13 +3,14 @@ use std::{convert::From, sync::Arc}; use arrow::datatypes::DataType; use datafusion_common::{Column, DFField, DFSchema, ScalarValue}; use datafusion_expr::{ - expr::BinaryExpr, + expr::{BinaryExpr, Cast}, lit, utils::exprlist_to_fields, Between, BuiltinScalarFunction, Case, Expr, + GetIndexedField, Like, LogicalPlan, Operator, @@ -269,8 +270,8 @@ impl PyExpr { | Expr::IsNotFalse(expr) | Expr::IsNotUnknown(expr) | Expr::Negative(expr) - | Expr::GetIndexedField { expr, .. } - | Expr::Cast { expr, .. } + | Expr::GetIndexedField(GetIndexedField { expr, .. }) + | Expr::Cast(Cast { expr, .. }) | Expr::TryCast { expr, .. } | Expr::Sort { expr, .. } | Expr::InSubquery { expr, .. } => { @@ -485,6 +486,7 @@ impl PyExpr { ScalarValue::IntervalMonthDayNano(..) => "IntervalMonthDayNano", ScalarValue::List(..) => "List", ScalarValue::Struct(..) 
=> "Struct", + ScalarValue::FixedSizeBinary(_, _) => "FixedSizeBinary", }, Expr::ScalarFunction { fun, args: _ } => match fun { BuiltinScalarFunction::Abs => "Abs", @@ -496,7 +498,7 @@ impl PyExpr { ))) } }, - Expr::Cast { expr: _, data_type } => match data_type { + Expr::Cast(Cast { expr: _, data_type }) => match data_type { DataType::Null => "NULL", DataType::Boolean => "BOOLEAN", DataType::Int8 | DataType::UInt8 => "TINYINT", diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index ef34c66f9..d9c453d35 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -306,7 +306,7 @@ impl ContextProvider for DaskSQLContext { None } - fn get_variable_type(&self, _: &[String]) -> Option { + fn get_variable_type(&self, _: &[String]) -> Option { unimplemented!("RUST: get_variable_type is not yet implemented for DaskSQLContext") } } From b50eb6e14292241cf6bde518605c1b1529d31464 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Tue, 25 Oct 2022 17:07:41 -0400 Subject: [PATCH 04/19] Pin dask/distributed for release (#891) * Pin dask/distributed for release * Fix pinning * Pin dask/distributed to 2022.9.2 for release * Pin dask/distributed to 2022.10.0 for release --- continuous_integration/environment-3.10-dev.yaml | 2 +- continuous_integration/environment-3.9-dev.yaml | 2 +- continuous_integration/gpuci/environment.yaml | 2 +- continuous_integration/recipe/meta.yaml | 2 +- docker/conda.txt | 2 +- docs/environment.yml | 2 +- docs/requirements-docs.txt | 2 +- setup.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/continuous_integration/environment-3.10-dev.yaml b/continuous_integration/environment-3.10-dev.yaml index e86d4e62f..a0ccd9dc5 100644 --- a/continuous_integration/environment-3.10-dev.yaml +++ b/continuous_integration/environment-3.10-dev.yaml @@ -4,7 +4,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0 +- dask>=2022.3.0,<=2022.10.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/environment-3.9-dev.yaml b/continuous_integration/environment-3.9-dev.yaml index 8a2a2bcb0..4387d6804 100644 --- a/continuous_integration/environment-3.9-dev.yaml +++ b/continuous_integration/environment-3.9-dev.yaml @@ -4,7 +4,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0 +- dask>=2022.3.0,<=2022.10.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/gpuci/environment.yaml b/continuous_integration/gpuci/environment.yaml index c839083e6..cde7b3472 100644 --- a/continuous_integration/gpuci/environment.yaml +++ b/continuous_integration/gpuci/environment.yaml @@ -7,7 +7,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0 +- dask>=2022.3.0,<=2022.10.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml index ab5274076..3725bd24c 100644 --- a/continuous_integration/recipe/meta.yaml +++ b/continuous_integration/recipe/meta.yaml @@ -30,7 +30,7 @@ requirements: - setuptools-rust >=1.4.1 run: - python - - dask >=2022.3.0 + - dask >=2022.3.0,<=2022.10.0 - pandas >=1.4.0 - fastapi >=0.69.0 - uvicorn >=0.13.4 diff --git a/docker/conda.txt b/docker/conda.txt index 32a08c7a9..37ee312ab 100644 --- a/docker/conda.txt +++ b/docker/conda.txt @@ -1,5 +1,5 @@ python>=3.8 -dask>=2022.3.0 +dask>=2022.3.0,<=2022.10.0 pandas>=1.4.0 jpype1>=1.0.2 openjdk>=8 diff --git 
a/docs/environment.yml b/docs/environment.yml index 5d562c532..941d23496 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -7,7 +7,7 @@ dependencies: - sphinx>=4.0.0 - sphinx-tabs - dask-sphinx-theme>=2.0.3 - - dask>=2022.3.0 + - dask>=2022.3.0,<=2022.10.0 - pandas>=1.4.0 - fugue>=0.7.0 - jpype1>=1.0.2 diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 439516478..3675acd15 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -1,7 +1,7 @@ sphinx>=4.0.0 sphinx-tabs dask-sphinx-theme>=3.0.0 -dask>=2022.3.0 +dask>=2022.3.0,<=2022.10.0 pandas>=1.4.0 fugue>=0.7.0 fastapi>=0.69.0 diff --git a/setup.py b/setup.py index c982e40a0..c48ec15d2 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ python_requires=">=3.8", setup_requires=sphinx_requirements, install_requires=[ - "dask[dataframe,distributed]>=2022.3.0", + "dask[dataframe,distributed]>=2022.3.0,<=2022.10.0", "pandas>=1.4.0", "fastapi>=0.69.0", "uvicorn>=0.13.4", From 2dfa5f711dbc7e5e77c0dc01798ff9f098aefe5b Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Wed, 26 Oct 2022 12:04:13 -0400 Subject: [PATCH 05/19] Unpin dask/distributed for development (#892) --- continuous_integration/environment-3.10-dev.yaml | 2 +- continuous_integration/environment-3.9-dev.yaml | 2 +- continuous_integration/gpuci/environment.yaml | 2 +- continuous_integration/recipe/meta.yaml | 2 +- docker/conda.txt | 2 +- docs/environment.yml | 2 +- docs/requirements-docs.txt | 2 +- setup.py | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/continuous_integration/environment-3.10-dev.yaml b/continuous_integration/environment-3.10-dev.yaml index a0ccd9dc5..e86d4e62f 100644 --- a/continuous_integration/environment-3.10-dev.yaml +++ b/continuous_integration/environment-3.10-dev.yaml @@ -4,7 +4,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0,<=2022.10.0 +- dask>=2022.3.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/environment-3.9-dev.yaml b/continuous_integration/environment-3.9-dev.yaml index 4387d6804..8a2a2bcb0 100644 --- a/continuous_integration/environment-3.9-dev.yaml +++ b/continuous_integration/environment-3.9-dev.yaml @@ -4,7 +4,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0,<=2022.10.0 +- dask>=2022.3.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/gpuci/environment.yaml b/continuous_integration/gpuci/environment.yaml index cde7b3472..c839083e6 100644 --- a/continuous_integration/gpuci/environment.yaml +++ b/continuous_integration/gpuci/environment.yaml @@ -7,7 +7,7 @@ channels: - nodefaults dependencies: - dask-ml>=2022.1.22 -- dask>=2022.3.0,<=2022.10.0 +- dask>=2022.3.0 - fastapi>=0.69.0 - fugue>=0.7.0 - intake>=0.6.0 diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml index 3725bd24c..ab5274076 100644 --- a/continuous_integration/recipe/meta.yaml +++ b/continuous_integration/recipe/meta.yaml @@ -30,7 +30,7 @@ requirements: - setuptools-rust >=1.4.1 run: - python - - dask >=2022.3.0,<=2022.10.0 + - dask >=2022.3.0 - pandas >=1.4.0 - fastapi >=0.69.0 - uvicorn >=0.13.4 diff --git a/docker/conda.txt b/docker/conda.txt index 37ee312ab..32a08c7a9 100644 --- a/docker/conda.txt +++ b/docker/conda.txt @@ -1,5 +1,5 @@ python>=3.8 -dask>=2022.3.0,<=2022.10.0 +dask>=2022.3.0 pandas>=1.4.0 jpype1>=1.0.2 openjdk>=8 diff --git a/docs/environment.yml 
b/docs/environment.yml index 941d23496..5d562c532 100644 --- a/docs/environment.yml +++ b/docs/environment.yml @@ -7,7 +7,7 @@ dependencies: - sphinx>=4.0.0 - sphinx-tabs - dask-sphinx-theme>=2.0.3 - - dask>=2022.3.0,<=2022.10.0 + - dask>=2022.3.0 - pandas>=1.4.0 - fugue>=0.7.0 - jpype1>=1.0.2 diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 3675acd15..439516478 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -1,7 +1,7 @@ sphinx>=4.0.0 sphinx-tabs dask-sphinx-theme>=3.0.0 -dask>=2022.3.0,<=2022.10.0 +dask>=2022.3.0 pandas>=1.4.0 fugue>=0.7.0 fastapi>=0.69.0 diff --git a/setup.py b/setup.py index c48ec15d2..c982e40a0 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ python_requires=">=3.8", setup_requires=sphinx_requirements, install_requires=[ - "dask[dataframe,distributed]>=2022.3.0,<=2022.10.0", + "dask[dataframe,distributed]>=2022.3.0", "pandas>=1.4.0", "fastapi>=0.69.0", "uvicorn>=0.13.4", From 1d6b7375cf7bd038e6d4395d13fd67695ae288be Mon Sep 17 00:00:00 2001 From: ChrisJar Date: Thu, 27 Oct 2022 23:29:24 +0200 Subject: [PATCH 06/19] Add replace operator (#897) * Add replace operator * Add unit tests Co-authored-by: Chris Jarrett --- dask_sql/physical/rex/core/call.py | 14 ++++++++++++++ tests/integration/test_rex.py | 6 +++++- tests/unit/test_call.py | 3 +++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 1903c8fd9..a66b178dc 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -535,6 +535,19 @@ def trim(self, s, search): return strip_call(search) +class ReplaceOperation(Operation): + """The replace operator (replace occurrences of pattern in a string)""" + + def __init__(self): + super().__init__(self.replace) + + def replace(self, s, pat, repl): + if is_frame(s): + s = s.str + + return s.replace(pat, repl) + + class OverlayOperation(Operation): """The overlay operator (replace string according to positions)""" @@ -965,6 +978,7 @@ class RexCallPlugin(BaseRexPlugin): "substr": SubStringOperation(), "substring": SubStringOperation(), "initcap": TensorScalarOperation(lambda x: x.str.title(), lambda x: x.title()), + "replace": ReplaceOperation(), # date/time operations "extract": ExtractOperation(), "localtime": Operation(lambda *args: pd.Timestamp.now()), diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index 655ff69de..b7d455fe3 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -522,7 +522,9 @@ def test_string_functions(c, gpu): SUBSTR(a, 3, 6) AS s, INITCAP(a) AS t, INITCAP(UPPER(a)) AS u, - INITCAP(LOWER(a)) AS v + INITCAP(LOWER(a)) AS v, + REPLACE(a, 'r', 'l') as w, + REPLACE('Another String', 'th', 'b') as x FROM {input_table} """ @@ -555,6 +557,8 @@ def test_string_functions(c, gpu): "t": ["A Normal String"], "u": ["A Normal String"], "v": ["A Normal String"], + "w": ["a nolmal stling"], + "x": ["Anober String"], } ) diff --git a/tests/unit/test_call.py b/tests/unit/test_call.py index 0075c5cb5..05b116af8 100644 --- a/tests/unit/test_call.py +++ b/tests/unit/test_call.py @@ -182,6 +182,9 @@ def test_string_operations(): assert ops_mapping["substring"](a, 2) == " normal string" assert ops_mapping["substring"](a, 2, 2) == " n" assert ops_mapping["initcap"](a) == "A Normal String" + assert ops_mapping["replace"](a, "nor", "") == "a mal string" + assert ops_mapping["replace"](a, "normal", "new") == "a new string" + assert 
ops_mapping["replace"]("hello", "", "w") == "whwewlwlwow" def test_dates(): From fd68c287e4f8f808137c594c0c783254221fe56d Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Wed, 2 Nov 2022 06:19:29 -0700 Subject: [PATCH 07/19] Replace `variadic` with `exact` where appropriate (#885) * replace variadic with exact where appropriate * create generate_numeric_signatures function * lint, rand/rand_integer, atan2/var_pop * style fix * generate_signatures and test * style fix --- dask_planner/src/sql.rs | 159 +++++++++++++++++++++++++++++++++------- 1 file changed, 132 insertions(+), 27 deletions(-) diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index d9c453d35..5f5baaa34 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -148,32 +148,38 @@ impl ContextProvider for DaskSQLContext { match name { "year" => { - let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable); + let sig = generate_numeric_signatures(1); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } - "atan2" | "mod" => { - let sig = Signature::variadic( - vec![DataType::Float64, DataType::Float64], - Volatility::Immutable, - ); + "mod" => { + let sig = generate_numeric_signatures(2); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "cbrt" | "cot" | "degrees" | "radians" | "sign" | "truncate" => { - let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable); + let sig = generate_numeric_signatures(1); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "rand" => { - let sig = Signature::variadic(vec![DataType::Int64], Volatility::Volatile); + let sig = Signature::one_of( + vec![ + TypeSignature::Exact(vec![]), + TypeSignature::Exact(vec![DataType::Int64]), + ], + Volatility::Immutable, + ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); } "rand_integer" => { - let sig = Signature::variadic( - vec![DataType::Int64, DataType::Int64], - Volatility::Volatile, + let sig = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64]), + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), + ], + Volatility::Immutable, ); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun))); @@ -231,39 +237,28 @@ impl ContextProvider for DaskSQLContext { match name { "every" => { - let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable); + let sig = generate_numeric_signatures(1); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "bit_and" | "bit_or" => { - let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable); + let sig = generate_numeric_signatures(1); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "single_value" => { - let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable); + let sig = generate_numeric_signatures(1); let rtf: ReturnTypeFunction = Arc::new(|input_types| Ok(Arc::new(input_types[0].clone()))); return 
Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "regr_count" => { - let sig = Signature::variadic( - vec![DataType::Float64, DataType::Float64], - Volatility::Immutable, - ); + let sig = generate_numeric_signatures(2); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } "regr_syy" | "regr_sxx" => { - let sig = Signature::variadic( - vec![DataType::Float64, DataType::Float64], - Volatility::Immutable, - ); - let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); - return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); - } - "var_pop" => { - let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable); + let sig = generate_numeric_signatures(2); let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64))); return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st))); } @@ -576,3 +571,113 @@ impl PlanVisitor for OptimizablePlanVisitor { Ok(true) } } + +fn generate_numeric_signatures(n: i32) -> Signature { + // Generates all combinations of vectors of length n, + // i.e., the Cartesian product + let datatypes = vec![ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + ]; + let mut cartesian_setup = vec![]; + // cartesian_setup = [datatypes, datatypes] when n == 2, etc. + for _ in 0..n { + cartesian_setup.push(datatypes.clone()); + } + + let mut exact_vector = vec![]; + let mut datatypes_iter = cartesian_setup.iter(); + // First pass + if let Some(first_iter) = datatypes_iter.next() { + for datatype in first_iter { + exact_vector.push(vec![datatype.clone()]); + } + } + // Generate list of lists with length n + for iter in datatypes_iter { + let mut outer_temp = vec![]; + for outer_datatype in exact_vector { + for inner_datatype in iter { + let mut inner_temp = outer_datatype.clone(); + inner_temp.push(inner_datatype.clone()); + outer_temp.push(inner_temp); + } + } + exact_vector = outer_temp; + } + + // Create vector of TypeSignatures + let mut one_of_vector = vec![]; + for vector in exact_vector.iter() { + one_of_vector.push(TypeSignature::Exact(vector.clone())); + } + + Signature::one_of(one_of_vector.clone(), Volatility::Immutable) +} + +#[allow(dead_code)] +fn generate_signatures(cartesian_setup: Vec>) -> Signature { + let mut exact_vector = vec![]; + let mut datatypes_iter = cartesian_setup.iter(); + // First pass + if let Some(first_iter) = datatypes_iter.next() { + for datatype in first_iter { + exact_vector.push(vec![datatype.clone()]); + } + } + // Generate the Cartesian product + for iter in datatypes_iter { + let mut outer_temp = vec![]; + for outer_datatype in exact_vector { + for inner_datatype in iter { + let mut inner_temp = outer_datatype.clone(); + inner_temp.push(inner_datatype.clone()); + outer_temp.push(inner_temp); + } + } + exact_vector = outer_temp; + } + + // Create vector of TypeSignatures + let mut one_of_vector = vec![]; + for vector in exact_vector.iter() { + one_of_vector.push(TypeSignature::Exact(vector.clone())); + } + + Signature::one_of(one_of_vector.clone(), Volatility::Immutable) +} + +#[cfg(test)] +mod test { + use arrow::datatypes::DataType; + use datafusion_expr::{Signature, TypeSignature, Volatility}; + + use crate::sql::generate_signatures; + + #[test] + fn test_generate_signatures() { + let sig = 
generate_signatures(vec![ + vec![DataType::Int64, DataType::Float64], + vec![DataType::Utf8, DataType::Int64], + ]); + let expected = Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]), + ], + Volatility::Immutable, + ); + assert_eq!(sig, expected); + } +} From 10de5ef592c65adfe6e261c047726a1a53eba791 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 2 Nov 2022 09:48:08 -0400 Subject: [PATCH 08/19] Bump pyo3 from 0.17.2 to 0.17.3 in /dask_planner (#900) Bumps [pyo3](https://github.com/pyo3/pyo3) from 0.17.2 to 0.17.3. - [Release notes](https://github.com/pyo3/pyo3/releases) - [Changelog](https://github.com/PyO3/pyo3/blob/v0.17.3/CHANGELOG.md) - [Commits](https://github.com/pyo3/pyo3/compare/v0.17.2...v0.17.3) --- updated-dependencies: - dependency-name: pyo3 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dask_planner/Cargo.lock | 20 ++++++++++---------- dask_planner/Cargo.toml | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index b3592afe0..ec8dfeac3 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -900,9 +900,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "201b6887e5576bf2f945fe65172c1fcbf3fcf285b23e4d71eb171d9736e38d32" +checksum = "268be0c73583c183f2b14052337465768c07726936a260f480f0857cb95ba543" dependencies = [ "cfg-if", "indoc", @@ -917,9 +917,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf0708c9ed01692635cbf056e286008e5a2927ab1a5e48cdd3aeb1ba5a6fef47" +checksum = "28fcd1e73f06ec85bf3280c48c67e731d8290ad3d730f8be9dc07946923005c8" dependencies = [ "once_cell", "target-lexicon", @@ -927,9 +927,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90352dea4f486932b72ddf776264d293f85b79a1d214de1d023927b41461132d" +checksum = "0f6cb136e222e49115b3c51c32792886defbfb0adead26a688142b346a0b9ffc" dependencies = [ "libc", "pyo3-build-config", @@ -937,9 +937,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7eb24b804a2d9e88bfcc480a5a6dd76f006c1e3edaf064e8250423336e2cd79d" +checksum = "94144a1266e236b1c932682136dc35a9dee8d3589728f68130c7c3861ef96b28" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -949,9 +949,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.17.2" +version = "0.17.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f22bb49f6a7348c253d7ac67a6875f2dc65f36c2ae64a82c381d528972bea6d6" +checksum = "c8df9be978a2d2f0cdebabb03206ed73b11314701a5bfe71b0d753b81997777f" dependencies = [ "proc-macro2", "quote", diff 
--git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml index 88b8921fa..04079c92b 100644 --- a/dask_planner/Cargo.toml +++ b/dask_planner/Cargo.toml @@ -19,7 +19,7 @@ env_logger = "0.9" log = "^0.4" mimalloc = { version = "*", default-features = false } parking_lot = "0.12" -pyo3 = { version = "0.17.2", features = ["extension-module", "abi3", "abi3-py38"] } +pyo3 = { version = "0.17.3", features = ["extension-module", "abi3", "abi3-py38"] } rand = "0.8" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } uuid = { version = "1.2", features = ["v4"] } From 9bb37a740c4cd104c074afbdd61b80230dfab94d Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Mon, 7 Nov 2022 09:49:13 -0800 Subject: [PATCH 09/19] Sort + limit topk optimization (initial) (#893) * Rust:Add method to retreive fetch rows during sort * Update sort plugin to use nsmallest/largest if applicalbe * move topk optimization checks to it's own function * Fix check for is_topk_optimizable * Add topk tests * Un-xfail q4 * Add sort topk-nelem-limit as config option * Add check for topk config option to is_topk_optimizable * Add more topk sort tests * use common variable for rel.sort plan * Apply suggestions from code review Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> --- dask_planner/src/sql/logical/sort.rs | 5 ++ dask_sql/physical/rel/logical/sort.py | 9 ++- dask_sql/physical/utils/sort.py | 58 +++++++++++++++++++ dask_sql/sql-schema.yaml | 20 +++++-- dask_sql/sql.yaml | 5 +- tests/integration/test_sort.py | 83 +++++++++++++++++++++++++++ tests/unit/test_queries.py | 1 - 7 files changed, 172 insertions(+), 9 deletions(-) diff --git a/dask_planner/src/sql/logical/sort.rs b/dask_planner/src/sql/logical/sort.rs index 0bdd67e23..06d35a28f 100644 --- a/dask_planner/src/sql/logical/sort.rs +++ b/dask_planner/src/sql/logical/sort.rs @@ -19,6 +19,11 @@ impl PySort { pub fn sort_expressions(&self) -> PyResult> { py_expr_list(&self.sort.input, &self.sort.expr) } + + #[pyo3(name = "getNumRows")] + pub fn get_fetch_val(&self) -> PyResult> { + Ok(self.sort.fetch) + } } impl TryFrom for PySort { diff --git a/dask_sql/physical/rel/logical/sort.py b/dask_sql/physical/rel/logical/sort.py index 9800df978..2e1376d41 100644 --- a/dask_sql/physical/rel/logical/sort.py +++ b/dask_sql/physical/rel/logical/sort.py @@ -20,16 +20,19 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai (dc,) = self.assert_inputs(rel, 1, context) df = dc.df cc = dc.column_container - sort_expressions = rel.sort().getCollation() + sort_plan = rel.sort() + sort_expressions = sort_plan.getCollation() sort_columns = [ cc.get_backend_by_frontend_name(expr.column_name(rel)) for expr in sort_expressions ] sort_ascending = [expr.isSortAscending() for expr in sort_expressions] sort_null_first = [expr.isSortNullsFirst() for expr in sort_expressions] + sort_num_rows = sort_plan.getNumRows() - df = df.persist() - df = apply_sort(df, sort_columns, sort_ascending, sort_null_first) + df = apply_sort( + df, sort_columns, sort_ascending, sort_null_first, sort_num_rows + ) cc = self.fix_column_to_row_type(cc, rel.getRowType()) # No column type has changed, so no need to cast again diff --git a/dask_sql/physical/utils/sort.py b/dask_sql/physical/utils/sort.py index c2ccce3c1..e53134d44 100644 --- a/dask_sql/physical/utils/sort.py +++ b/dask_sql/physical/utils/sort.py @@ -2,6 +2,7 @@ 
import dask.dataframe as dd import pandas as pd +from dask import config as dask_config from dask.utils import M from dask_sql.utils import make_pickable_without_dask_sql @@ -12,6 +13,7 @@ def apply_sort( sort_columns: List[str], sort_ascending: List[bool], sort_null_first: List[bool], + sort_num_rows: int = None, ) -> dd.DataFrame: # when sort_values doesn't support lists of ascending / null # position booleans, we can still do the sort provided that @@ -19,6 +21,24 @@ def apply_sort( single_ascending = len(set(sort_ascending)) == 1 single_null_first = len(set(sort_null_first)) == 1 + if is_topk_optimizable( + df=df, + sort_columns=sort_columns, + single_ascending=single_ascending, + sort_null_first=sort_null_first, + sort_num_rows=sort_num_rows, + ): + return topk_sort( + df=df, + sort_columns=sort_columns, + sort_ascending=sort_ascending, + sort_num_rows=sort_num_rows, + ) + + else: + # Pre persist before sort to avoid duplicate compute + df = df.persist() + # pandas / cudf don't support lists of null positions if df.npartitions == 1 and single_null_first: return df.map_partitions( @@ -57,6 +77,18 @@ def apply_sort( ).persist() +def topk_sort( + df: dd.DataFrame, + sort_columns: List[str], + sort_ascending: List[bool], + sort_num_rows: int = None, +): + if sort_ascending[0]: + return df.nsmallest(n=sort_num_rows, columns=sort_columns) + else: + return df.nlargest(n=sort_num_rows, columns=sort_columns) + + def sort_partition_func( partition: pd.DataFrame, sort_columns: List[str], @@ -85,3 +117,29 @@ def sort_partition_func( ) return partition + + +def is_topk_optimizable( + df: dd.DataFrame, + sort_columns: List[str], + single_ascending: bool, + sort_null_first: List[bool], + sort_num_rows: int = None, +): + if ( + sort_num_rows is None + or not single_ascending + or any(sort_null_first) + # pandas doesnt support nsmallest/nlargest with object dtypes + or ( + "pandas" in str(df._partition_type) + and any(df[sort_columns].dtypes == "object") + ) + or ( + sort_num_rows * len(df.columns) + > dask_config.get("sql.sort.topk-nelem-limit") + ) + ): + return False + + return True diff --git a/dask_sql/sql-schema.yaml b/dask_sql/sql-schema.yaml index c6d5bd3c0..993bf0031 100644 --- a/dask_sql/sql-schema.yaml +++ b/dask_sql/sql-schema.yaml @@ -40,12 +40,24 @@ properties: queries, but can signicantly reduce memory usage when querying a small subset of a large table. Default is ``true``. - predicate_pushdown: + optimize: type: boolean description: | - Whether to try pushing down filter predicates into IO (when possible). + Whether the first generated logical plan should be further optimized or used as is. - optimize: + predicate_pushdown: type: boolean description: | - Whether the first generated logical plan should be further optimized or used as is. + Whether to try pushing down filter predicates into IO (when possible). + + sort: + type: object + properties: + + topk-nelem-limit: + type: integer + description: | + Total number of elements below which dask-sql should attempt to apply the top-k + optimization (when possible). ``nelem`` is defined as the limit or ``k`` value times the + number of columns. Default is 1000000, corresponding to a LIMIT clause of 1 million in a + 1 column table. 
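The change in sort.py above boils down to swapping a full shuffle sort plus limit for Dask's nsmallest/nlargest reductions whenever the query has a LIMIT, a single sort direction, and no NULLS FIRST ordering. A rough standalone sketch of the equivalence on a plain Dask dataframe with invented data:

import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(
    pd.DataFrame({"a": range(100), "b": range(100, 0, -1)}), npartitions=10
)

# Top-k path: keep only the k best rows per partition, then reduce.
topk = ddf.nsmallest(n=5, columns=["a"]).compute()

# General path: full sort across partitions, then take the first k rows.
full = ddf.sort_values("a").head(5, npartitions=-1)

assert topk.reset_index(drop=True).equals(full.reset_index(drop=True))

This is also why is_topk_optimizable refuses the shortcut for mixed sort directions, NULLS FIRST, object-dtype pandas columns, or when the requested number of elements exceeds the sql.sort.topk-nelem-limit option; those cases fall back to the regular sort path.
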
diff --git a/dask_sql/sql.yaml b/dask_sql/sql.yaml index 5c175320d..22ff68f70 100644 --- a/dask_sql/sql.yaml +++ b/dask_sql/sql.yaml @@ -9,6 +9,9 @@ sql: limit: check-first-partition: True + optimize: True + predicate_pushdown: True - optimize: True + sort: + topk-nelem-limit: 1000000 diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 8b1d125a9..1956a3bce 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -351,3 +351,86 @@ def test_sort_by_old_alias(c, input_table_1, request): df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[ ["b"] ] + + +@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)]) +def test_sort_topk(gpu): + c = Context() + df = pd.DataFrame( + { + "a": [float("nan"), 1] * 30, + "b": [1, 2, 3] * 20, + "c": ["a", "b", "c"] * 20, + } + ) + c.create_table("df", dd.from_pandas(df, npartitions=10), gpu=gpu) + + df_result = c.sql("""SELECT * FROM df ORDER BY a LIMIT 10""") + assert any(["nsmallest" in key for key in df_result.dask.layers.keys()]) + assert_eq( + df_result, + pd.DataFrame( + { + "a": [1.0] * 10, + "b": ([2, 1, 3] * 4)[:10], + "c": (["b", "a", "c"] * 4)[:10], + } + ), + check_index=False, + ) + + df_result = c.sql("""SELECT * FROM df ORDER BY a, b LIMIT 10""") + assert any(["nsmallest" in key for key in df_result.dask.layers.keys()]) + assert_eq( + df_result, + pd.DataFrame({"a": [1.0] * 10, "b": [1] * 10, "c": ["a"] * 10}), + check_index=False, + ) + + df_result = c.sql( + """SELECT * FROM df ORDER BY a DESC NULLS LAST, b DESC NULLS LAST LIMIT 10""" + ) + assert any(["nlargest" in key for key in df_result.dask.layers.keys()]) + assert_eq( + df_result, + pd.DataFrame({"a": [1.0] * 10, "b": [3] * 10, "c": ["c"] * 10}), + check_index=False, + ) + + # String column nlargest/smallest not supported for pandas + df_result = c.sql("""SELECT * FROM df ORDER BY c LIMIT 10""") + if not gpu: + assert all(["nlargest" not in key for key in df_result.dask.layers.keys()]) + assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()]) + else: + assert_eq( + df_result, + pd.DataFrame({"a": [float("nan"), 1] * 5, "b": [1] * 10, "c": ["a"] * 10}), + check_index=False, + ) + + # Assert that the optimization isn't applied when there is any nulls first + df_result = c.sql( + """SELECT * FROM df ORDER BY a DESC, b DESC NULLS LAST LIMIT 10""" + ) + assert all(["nlargest" not in key for key in df_result.dask.layers.keys()]) + assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()]) + + # Assert optimization isn't applied for mixed asc + desc sort + df_result = c.sql("""SELECT * FROM df ORDER BY a, b DESC NULLS LAST LIMIT 10""") + assert all(["nlargest" not in key for key in df_result.dask.layers.keys()]) + assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()]) + + # Assert optimization isn't applied when the number of requested elements + # exceed topk-nelem-limit config value + # Default topk-nelem-limit is 1M and 334k*3columns takes it above this limit + df_result = c.sql("""SELECT * FROM df ORDER BY a, b LIMIT 333334""") + assert all(["nlargest" not in key for key in df_result.dask.layers.keys()]) + assert all(["nsmallest" not in key for key in df_result.dask.layers.keys()]) + + df_result = c.sql( + """SELECT * FROM df ORDER BY a, b LIMIT 10""", + config_options={"sql.sort.topk-nelem-limit": 29}, + ) + assert all(["nlargest" not in key for key in df_result.dask.layers.keys()]) + assert all(["nsmallest" not in 
key for key in df_result.dask.layers.keys()]) diff --git a/tests/unit/test_queries.py b/tests/unit/test_queries.py index 012f84e2f..f35bd5750 100644 --- a/tests/unit/test_queries.py +++ b/tests/unit/test_queries.py @@ -3,7 +3,6 @@ import pytest XFAIL_QUERIES = ( - 4, 5, 6, 8, From 65b37146bbd498924e2d01e5f5efcf9678722db9 Mon Sep 17 00:00:00 2001 From: Nick Vazquez Date: Mon, 7 Nov 2022 12:22:19 -0800 Subject: [PATCH 10/19] [bug][docs] `my_ds` -> `my_df` (#905) Looks like it is supposed to be `my_df` from the example code further down the page --- docs/source/data_input.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/data_input.rst b/docs/source/data_input.rst index 7bbae23f2..8cb189d6e 100644 --- a/docs/source/data_input.rst +++ b/docs/source/data_input.rst @@ -107,7 +107,7 @@ and then later register it in the :class:`~dask_sql.Context` via SQL: # a dask.distributed Client client = Client(...) - client.publish_dataset(my_ds=df) + client.publish_dataset(my_df=df) Later in SQL: @@ -119,7 +119,7 @@ Later in SQL: CREATE TABLE my_data WITH ( format = 'memory', - location = 'my_ds' + location = 'my_df' ) .. group-tab:: GPU @@ -128,7 +128,7 @@ Later in SQL: CREATE TABLE my_data WITH ( format = 'memory', - location = 'my_ds', + location = 'my_df', gpu = True ) From 9f97cc7413c1013217b89417c9d1be9d4e0226cc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 8 Nov 2022 07:59:36 -0800 Subject: [PATCH 11/19] Bump env_logger from 0.9.1 to 0.9.3 in /dask_planner (#906) Bumps [env_logger](https://github.com/env-logger-rs/env_logger) from 0.9.1 to 0.9.3. - [Release notes](https://github.com/env-logger-rs/env_logger/releases) - [Changelog](https://github.com/env-logger-rs/env_logger/blob/main/CHANGELOG.md) - [Commits](https://github.com/env-logger-rs/env_logger/compare/v0.9.1...v0.9.3) --- updated-dependencies: - dependency-name: env_logger dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dask_planner/Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index ec8dfeac3..e4eef7ccf 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -444,9 +444,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.9.1" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c90bf5f19754d10198ccb95b70664fc925bd1fc090a0fd9a6ebc54acc8cd6272" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" dependencies = [ "atty", "humantime", From 6844c74fc99c78348051119bfa15f3aad53293e6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 9 Nov 2022 09:21:42 -0800 Subject: [PATCH 12/19] Bump mimalloc from 0.1.30 to 0.1.31 in /dask_planner (#910) Bumps [mimalloc](https://github.com/purpleprotocol/mimalloc_rust) from 0.1.30 to 0.1.31. - [Release notes](https://github.com/purpleprotocol/mimalloc_rust/releases) - [Commits](https://github.com/purpleprotocol/mimalloc_rust/compare/v0.1.30...v0.1.31) --- updated-dependencies: - dependency-name: mimalloc dependency-type: direct:production update-type: version-update:semver-patch ... 
Signed-off-by: dependabot[bot] Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- dask_planner/Cargo.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock index e4eef7ccf..b6586e238 100644 --- a/dask_planner/Cargo.lock +++ b/dask_planner/Cargo.lock @@ -661,9 +661,9 @@ checksum = "292a948cd991e376cf75541fe5b97a1081d713c618b4f1b9500f8844e49eb565" [[package]] name = "libmimalloc-sys" -version = "0.1.26" +version = "0.1.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc093ab289b0bfda3aa1bdfab9c9542be29c7ef385cfcbe77f8c9813588eb48" +checksum = "c37567b180c1af25924b303ddf1ee4467653783440c62360beb2b322a4d93361" dependencies = [ "cc", ] @@ -713,9 +713,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ce6a4b40d3bff9eb3ce9881ca0737a85072f9f975886082640cd46a75cdb35" +checksum = "b32d6a9ac92d0239d7bfa31137fb47634ac7272a3c11bcee91379ac100781670" dependencies = [ "libmimalloc-sys", ] From 5440effe0d26c746650c4e801c7f846efd17f5c9 Mon Sep 17 00:00:00 2001 From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com> Date: Mon, 14 Nov 2022 14:23:54 -0800 Subject: [PATCH 13/19] Replace `dask_ml.wrappers.Incremental` with custom `Incremental` class (#855) * Create metrics.py * add incremental functionality * lint and some comments * update more comments * add dask-ml fit function * style fix * DASK_2022_01_0 * add unit tests * style fix * remove scheduler * experiment_class comment * apply Vibhu's suggestions * style fix --- .../physical/rel/custom/create_experiment.py | 16 +- dask_sql/physical/rel/custom/create_model.py | 33 +- dask_sql/physical/rel/custom/metrics.py | 206 ++++++++++++ dask_sql/physical/rel/custom/wrappers.py | 317 +++++++++++++++++- docs/source/sql/ml.rst | 28 +- tests/unit/test_ml_wrappers.py | 250 ++++++++++++++ 6 files changed, 797 insertions(+), 53 deletions(-) create mode 100644 dask_sql/physical/rel/custom/metrics.py create mode 100644 tests/unit/test_ml_wrappers.py diff --git a/dask_sql/physical/rel/custom/create_experiment.py b/dask_sql/physical/rel/custom/create_experiment.py index 3d510ac18..ddec9fccf 100644 --- a/dask_sql/physical/rel/custom/create_experiment.py +++ b/dask_sql/physical/rel/custom/create_experiment.py @@ -30,17 +30,9 @@ class CreateExperimentPlugin(BaseRelPlugin): * model_class: Full path to the class of the model which has to be tuned. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - [dask-ml documentation](https://ml.dask.org/index.html) - for more information on which models work best. You might need to install necessary packages to use the models. * experiment_class : Full path of the Hyperparameter tuner - from dask_ml, choose dask tuner class carefully based on what you - exactly need (memory vs compute constrains), refer: - [dask-ml documentation](https://ml.dask.org/hyper-parameter-search.html) - (for tuning hyperparameter of the models both model_class and experiment class are - required parameters.) 
* tune_parameters: Key-value of pairs of Hyperparameters to tune, i.e Search Space for particular model to tune @@ -64,7 +56,7 @@ class CreateExperimentPlugin(BaseRelPlugin): CREATE EXPERIMENT my_exp WITH( model_class = 'sklearn.ensemble.GradientBoostingClassifier', - experiment_class = 'dask_ml.model_selection.GridSearchCV', + experiment_class = 'sklearn.model_selection.GridSearchCV', tune_parameters = (n_estimators = ARRAY [16, 32, 2], learning_rate = ARRAY [0.1,0.01,0.001], max_depth = ARRAY [3,4,5,10] @@ -174,7 +166,11 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai search = ExperimentClass(model, {**parameters}, **experiment_kwargs) logger.info(tune_fit_kwargs) - search.fit(X, y, **tune_fit_kwargs) + search.fit( + X.to_dask_array(lengths=True), + y.to_dask_array(lengths=True), + **tune_fit_kwargs, + ) df = pd.DataFrame(search.cv_results_) df["model_class"] = model_class diff --git a/dask_sql/physical/rel/custom/create_model.py b/dask_sql/physical/rel/custom/create_model.py index 179dd7971..726568c5d 100644 --- a/dask_sql/physical/rel/custom/create_model.py +++ b/dask_sql/physical/rel/custom/create_model.py @@ -32,9 +32,6 @@ class CreateModelPlugin(BaseRelPlugin): * model_class: Full path to the class of the model to train. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - [dask-ml documentation](https://ml.dask.org/index.html) - for more information on which models work best. You might need to install necessary packages to use the models. * target_column: Which column from the data to use as target. @@ -45,16 +42,12 @@ class CreateModelPlugin(BaseRelPlugin): want to set this parameter. * wrap_predict: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. - Have a look into the - [dask-ml docu](https://ml.dask.org/meta-estimators.html#parallel-prediction-and-transformation) - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if predicting on big data. + Defaults to false. Typically you set it to true for + sklearn models if predicting on big data. * wrap_fit: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.Incremental`. - Have a look into the - [dask-ml docu](https://ml.dask.org/incremental.html) - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if training on big data. + model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. + Defaults to false. Typically you set it to true for + sklearn models if training on big data. * fit_kwargs: keyword arguments sent to the call to fit(). All other arguments are passed to the constructor of the @@ -76,7 +69,7 @@ class CreateModelPlugin(BaseRelPlugin): Examples: CREATE MODEL my_model WITH ( - model_class = 'dask_ml.xgboost.XGBClassifier', + model_class = 'xgboost.XGBClassifier', target_column = 'target' ) AS ( SELECT x, y, target @@ -95,11 +88,10 @@ class CreateModelPlugin(BaseRelPlugin): dask dataframes. * if you are training on relatively small amounts - of data but predicting on large data samples - (and you are not using a model build for usage with dask - from the dask-ml package), you might want to set - `wrap_predict` to True. With this option, - model interference will be parallelized/distributed. + of data but predicting on large data samples, + you might want to set `wrap_predict` to True. 
+ With this option, model interference will be + parallelized/distributed. * If you are training on large amounts of data, you can try setting wrap_fit to True. This will do the same on the training step, but works only on @@ -158,10 +150,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai model = ModelClass(**kwargs) if wrap_fit: - try: - from dask_ml.wrappers import Incremental - except ImportError: # pragma: no cover - raise ValueError("Wrapping requires dask-ml to be installed.") + from dask_sql.physical.rel.custom.wrappers import Incremental model = Incremental(estimator=model) diff --git a/dask_sql/physical/rel/custom/metrics.py b/dask_sql/physical/rel/custom/metrics.py new file mode 100644 index 000000000..4b898d1a9 --- /dev/null +++ b/dask_sql/physical/rel/custom/metrics.py @@ -0,0 +1,206 @@ +# Copyright 2017, Dask developers +# Dask-ML project - https://github.com/dask/dask-ml +from typing import Optional, TypeVar + +import dask +import dask.array as da +import numpy as np +import sklearn.metrics +import sklearn.utils.multiclass +from dask.array import Array +from dask.utils import derived_from + +ArrayLike = TypeVar("ArrayLike", Array, np.ndarray) + + +def accuracy_score( + y_true: ArrayLike, + y_pred: ArrayLike, + normalize: bool = True, + sample_weight: Optional[ArrayLike] = None, + compute: bool = True, +) -> ArrayLike: + """Accuracy classification score. + In multilabel classification, this function computes subset accuracy: + the set of labels predicted for a sample must *exactly* match the + corresponding set of labels in y_true. + Read more in the :ref:`User Guide `. + Parameters + ---------- + y_true : 1d array-like, or label indicator array + Ground truth (correct) labels. + y_pred : 1d array-like, or label indicator array + Predicted labels, as returned by a classifier. + normalize : bool, optional (default=True) + If ``False``, return the number of correctly classified samples. + Otherwise, return the fraction of correctly classified samples. + sample_weight : 1d array-like, optional + Sample weights. + .. versionadded:: 0.7.0 + Returns + ------- + score : scalar dask Array + If ``normalize == True``, return the correctly classified samples + (float), else it returns the number of correctly classified samples + (int). + The best performance is 1 with ``normalize == True`` and the number + of samples with ``normalize == False``. + Notes + ----- + In binary and multiclass classification, this function is equal + to the ``jaccard_similarity_score`` function. + + """ + + if y_true.ndim > 1: + differing_labels = ((y_true - y_pred) == 0).all(1) + score = differing_labels != 0 + else: + score = y_true == y_pred + + if normalize: + score = da.average(score, weights=sample_weight) + elif sample_weight is not None: + score = da.dot(score, sample_weight) + else: + score = score.sum() + + if compute: + score = score.compute() + return score + + +def _log_loss_inner( + x: ArrayLike, y: ArrayLike, sample_weight: Optional[ArrayLike], **kwargs +): + # da.map_blocks wasn't able to concatenate together the results + # when we reduce down to a scalar per block. So we make an + # array with 1 element. 
+ if sample_weight is not None: + sample_weight = sample_weight.ravel() + return np.array( + [sklearn.metrics.log_loss(x, y, sample_weight=sample_weight, **kwargs)] + ) + + +def log_loss( + y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None, labels=None +): + if not (dask.is_dask_collection(y_true) and dask.is_dask_collection(y_pred)): + return sklearn.metrics.log_loss( + y_true, + y_pred, + eps=eps, + normalize=normalize, + sample_weight=sample_weight, + labels=labels, + ) + + if y_pred.ndim > 1 and y_true.ndim == 1: + y_true = y_true.reshape(-1, 1) + drop_axis: Optional[int] = 1 + if sample_weight is not None: + sample_weight = sample_weight.reshape(-1, 1) + else: + drop_axis = None + + result = da.map_blocks( + _log_loss_inner, + y_true, + y_pred, + sample_weight, + chunks=(1,), + drop_axis=drop_axis, + dtype="f8", + eps=eps, + normalize=normalize, + labels=labels, + ) + if normalize and sample_weight is not None: + sample_weight = sample_weight.ravel() + block_weights = sample_weight.map_blocks(np.sum, chunks=(1,), keepdims=True) + return da.average(result, 0, weights=block_weights) + elif normalize: + return result.mean() + else: + return result.sum() + + +def _check_sample_weight(sample_weight: Optional[ArrayLike]): + if sample_weight is not None: + raise ValueError("'sample_weight' is not supported.") + + +@derived_from(sklearn.metrics) +def mean_squared_error( + y_true: ArrayLike, + y_pred: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + multioutput: Optional[str] = "uniform_average", + squared: bool = True, + compute: bool = True, +) -> ArrayLike: + _check_sample_weight(sample_weight) + output_errors = ((y_pred - y_true) ** 2).mean(axis=0) + + if isinstance(multioutput, str) or multioutput is None: + if multioutput == "raw_values": + if compute: + return output_errors.compute() + else: + return output_errors + else: + raise ValueError("Weighted 'multioutput' not supported.") + result = output_errors.mean() + if not squared: + result = da.sqrt(result) + if compute: + result = result.compute() + return result + + +def _check_reg_targets( + y_true: ArrayLike, y_pred: ArrayLike, multioutput: Optional[str] +): + if multioutput is not None and multioutput != "uniform_average": + raise NotImplementedError("'multioutput' must be 'uniform_average'") + + if y_true.ndim == 1: + y_true = y_true.reshape((-1, 1)) + if y_pred.ndim == 1: + y_pred = y_pred.reshape((-1, 1)) + + # TODO: y_type, multioutput + return None, y_true, y_pred, multioutput + + +@derived_from(sklearn.metrics) +def r2_score( + y_true: ArrayLike, + y_pred: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + multioutput: Optional[str] = "uniform_average", + compute: bool = True, +) -> ArrayLike: + _check_sample_weight(sample_weight) + _, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, multioutput) + weight = 1.0 + + numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8") + denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8") + + nonzero_denominator = denominator != 0 + nonzero_numerator = numerator != 0 + valid_score = nonzero_denominator & nonzero_numerator + output_chunks = getattr(y_true, "chunks", [None, None])[1] + output_scores = da.ones([y_true.shape[1]], chunks=output_chunks) + with np.errstate(all="ignore"): + output_scores[valid_score] = 1 - ( + numerator[valid_score] / denominator[valid_score] + ) + output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0 + + result = output_scores.mean(axis=0) + if compute: + result = 
result.compute() + return result diff --git a/dask_sql/physical/rel/custom/wrappers.py b/dask_sql/physical/rel/custom/wrappers.py index 7ed0d0dea..c6432497b 100644 --- a/dask_sql/physical/rel/custom/wrappers.py +++ b/dask_sql/physical/rel/custom/wrappers.py @@ -3,11 +3,19 @@ """Meta-estimators for parallelizing estimators using the scikit-learn API.""" import logging import warnings +from typing import Any, Callable, Tuple, Union import dask.array as da import dask.dataframe as dd import dask.delayed import numpy as np +import sklearn.base +import sklearn.metrics +from dask.delayed import Delayed +from dask.highlevelgraph import HighLevelGraph +from sklearn.metrics import check_scoring as sklearn_check_scoring +from sklearn.metrics import make_scorer +from sklearn.utils.validation import check_is_fitted try: import sklearn.base @@ -15,9 +23,31 @@ except ImportError: # pragma: no cover raise ImportError("sklearn must be installed") +from dask_sql.physical.rel.custom.metrics import ( + accuracy_score, + log_loss, + mean_squared_error, + r2_score, +) + logger = logging.getLogger(__name__) +# Scorers +accuracy_scorer: Tuple[Any, Any] = (accuracy_score, {}) +neg_mean_squared_error_scorer = (mean_squared_error, dict(greater_is_better=False)) +r2_scorer: Tuple[Any, Any] = (r2_score, {}) +neg_log_loss_scorer = (log_loss, dict(greater_is_better=False, needs_proba=True)) + + +SCORERS = dict( + accuracy=accuracy_scorer, + neg_mean_squared_error=neg_mean_squared_error_scorer, + r2=r2_scorer, + neg_log_loss=neg_log_loss_scorer, +) + + class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixin): """Meta-estimator for parallel predict and transform. @@ -231,9 +261,7 @@ def score(self, X, y, compute=True): if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): scorer = sklearn.metrics.get_scorer(scoring) else: - # TODO: implement Dask-ML's get_scorer() function - # scorer = get_scorer(scoring, compute=compute) - raise NotImplementedError("get_scorer function not implemented") + scorer = get_scorer(scoring, compute=compute) return scorer(self, X, y) else: return self._postfit_estimator.score(X, y) @@ -386,6 +414,145 @@ def _check_method(self, method): return getattr(estimator, method) +class Incremental(ParallelPostFit): + """Metaestimator for feeding Dask Arrays to an estimator blockwise. + This wrapper provides a bridge between Dask objects and estimators + implementing the ``partial_fit`` API. These *incremental learners* can + train on batches of data. This fits well with Dask's blocked data + structures. + .. note:: + This meta-estimator is not appropriate for hyperparameter optimization + on larger-than-memory datasets. + See the `list of incremental learners`_ in the scikit-learn documentation + for a list of estimators that implement the ``partial_fit`` API. Note that + `Incremental` is not limited to just these classes, it will work on any + estimator implementing ``partial_fit``, including those defined outside of + scikit-learn itself. + Calling :meth:`Incremental.fit` with a Dask Array will pass each block of + the Dask array or arrays to ``estimator.partial_fit`` *sequentially*. + Like :class:`ParallelPostFit`, the methods available after fitting (e.g. + :meth:`Incremental.predict`, etc.) are all parallel and delayed. + The ``estimator_`` attribute is a clone of `estimator` that was actually + used during the call to ``fit``. All attributes learned during training + are available on ``Incremental`` directly. + .. 
_list of incremental learners: https://scikit-learn.org/stable/modules/computing.html#incremental-learning # noqa + Parameters + ---------- + estimator : Estimator + Any object supporting the scikit-learn ``partial_fit`` API. + scoring : string or callable, optional + A single string (see :ref:`scoring_parameter`) or a callable + (see :ref:`scoring`) to evaluate the predictions on the test set. + For evaluating multiple metrics, either give a list of (unique) + strings or a dict with names as keys and callables as values. + NOTE that when using custom scorers, each scorer should return a + single value. Metric functions returning a list/array of values + can be wrapped into multiple scorers that return one value each. + See :ref:`multimetric_grid_search` for an example. + .. warning:: + If None, the estimator's default scorer (if available) is used. + Most scikit-learn estimators will convert large Dask arrays to + a single NumPy array, which may exhaust the memory of your worker. + You probably want to always specify `scoring`. + random_state : int or numpy.random.RandomState, optional + Random object that determines how to shuffle blocks. + shuffle_blocks : bool, default True + Determines whether to call ``partial_fit`` on a randomly selected chunk + of the Dask arrays (default), or to fit in sequential order. This does + not control shuffle between blocks or shuffling each block. + predict_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + predict_proba_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``predict_proba`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + transform_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer) + An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output + type of the estimators ``transform`` call. + This meta is necessary for for some estimators to work with + ``dask.dataframe`` and ``dask.array`` + Attributes + ---------- + estimator_ : Estimator + A clone of `estimator` that was actually fit during the ``.fit`` call. 
+ + """ + + def __init__( + self, + estimator=None, + scoring=None, + shuffle_blocks=True, + random_state=None, + assume_equal_chunks=True, + predict_meta=None, + predict_proba_meta=None, + transform_meta=None, + ): + self.shuffle_blocks = shuffle_blocks + self.random_state = random_state + self.assume_equal_chunks = assume_equal_chunks + super(Incremental, self).__init__( + estimator=estimator, + scoring=scoring, + predict_meta=predict_meta, + predict_proba_meta=predict_proba_meta, + transform_meta=transform_meta, + ) + + @property + def _postfit_estimator(self): + check_is_fitted(self, "estimator_") + return self.estimator_ + + def _fit_for_estimator(self, estimator, X, y, **fit_kwargs): + check_scoring(estimator, self.scoring) + if not dask.is_dask_collection(X) and not dask.is_dask_collection(y): + result = estimator.partial_fit(X=X, y=y, **fit_kwargs) + else: + result = fit( + estimator, + X, + y, + random_state=self.random_state, + shuffle_blocks=self.shuffle_blocks, + assume_equal_chunks=self.assume_equal_chunks, + **fit_kwargs, + ) + + copy_learned_attributes(result, self) + self.estimator_ = result + return self + + def fit(self, X, y=None, **fit_kwargs): + estimator = sklearn.base.clone(self.estimator) + self._fit_for_estimator(estimator, X, y, **fit_kwargs) + return self + + def partial_fit(self, X, y=None, **fit_kwargs): + """Fit the underlying estimator. + If this estimator has not been previously fit, this is identical to + :meth:`Incremental.fit`. If it has been previously fit, + ``self.estimator_`` is used as the starting point. + Parameters + ---------- + X, y : array-like + **kwargs + Additional fit-kwargs for the underlying estimator. + Returns + ------- + self : object + """ + estimator = getattr(self, "estimator_", None) + if estimator is None: + estimator = sklearn.base.clone(self.estimator) + return self._fit_for_estimator(estimator, X, y, **fit_kwargs) + + def _predict(part, estimator, output_meta=None): if part.shape[0] == 0 and output_meta is not None: empty_output = handle_empty_partitions(output_meta) @@ -495,3 +662,147 @@ def copy_learned_attributes(from_estimator, to_estimator): for k, v in attrs.items(): setattr(to_estimator, k, v) + + +def get_scorer(scoring: Union[str, Callable], compute: bool = True) -> Callable: + """Get a scorer from string + Parameters + ---------- + scoring : str | callable + scoring method as string. If callable it is returned as is. + Returns + ------- + scorer : callable + The scorer. + """ + # This is the same as sklearns, only we use our SCORERS dict, + # and don't have back-compat code + if isinstance(scoring, str): + try: + scorer, kwargs = SCORERS[scoring] + except KeyError: + raise ValueError( + "{} is not a valid scoring value. " + "Valid options are {}".format(scoring, sorted(SCORERS)) + ) + else: + scorer = scoring + kwargs = {} + + kwargs["compute"] = compute + + return make_scorer(scorer, **kwargs) + + +def check_scoring(estimator, scoring=None, **kwargs): + res = sklearn_check_scoring(estimator, scoring=scoring, **kwargs) + if scoring in SCORERS.keys(): + func, kwargs = SCORERS[scoring] + return make_scorer(func, **kwargs) + return res + + +def fit( + model, + x, + y, + compute=True, + shuffle_blocks=True, + random_state=None, + assume_equal_chunks=False, + **kwargs, +): + """Fit scikit learn model against dask arrays + Model must support the ``partial_fit`` interface for online or batch + learning. + Ideally your rows are independent and identically distributed. 
By default, + this function will step through chunks of the arrays in random order. + Parameters + ---------- + model: sklearn model + Any model supporting partial_fit interface + x: dask Array + Two dimensional array, likely tall and skinny + y: dask Array + One dimensional array with same chunks as x's rows + compute : bool + Whether to compute this result + shuffle_blocks : bool + Whether to shuffle the blocks with ``random_state`` or not + random_state : int or numpy.random.RandomState + Random state to use when shuffling blocks + kwargs: + options to pass to partial_fit + """ + + nblocks, x_name = _blocks_and_name(x) + if y is not None: + y_nblocks, y_name = _blocks_and_name(y) + assert y_nblocks == nblocks + else: + y_name = "" + + if not hasattr(model, "partial_fit"): + msg = "The class '{}' does not implement 'partial_fit'." + raise ValueError(msg.format(type(model))) + + order = list(range(nblocks)) + if shuffle_blocks: + rng = sklearn.utils.check_random_state(random_state) + rng.shuffle(order) + + name = "fit-" + dask.base.tokenize(model, x, y, kwargs, order) + + if hasattr(x, "chunks") and x.ndim > 1: + x_extra = (0,) + else: + x_extra = () + + dsk = {(name, -1): model} + dsk.update( + { + (name, i): ( + _partial_fit, + (name, i - 1), + (x_name, order[i]) + x_extra, + (y_name, order[i]), + kwargs, + ) + for i in range(nblocks) + } + ) + + dependencies = [x] + if y is not None: + dependencies.append(y) + new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies) + value = Delayed((name, nblocks - 1), new_dsk, layer=name) + + if compute: + return value.compute() + else: + return value + + +def _blocks_and_name(obj): + if hasattr(obj, "chunks"): + nblocks = len(obj.chunks[0]) + name = obj.name + + elif hasattr(obj, "npartitions"): + # dataframe, bag + nblocks = obj.npartitions + if hasattr(obj, "_name"): + # dataframe + name = obj._name + else: + # bag + name = obj.name + + return nblocks, name + + +def _partial_fit(model, x, y, kwargs=None): + kwargs = kwargs or dict() + model.partial_fit(x, y, **kwargs) + return model diff --git a/docs/source/sql/ml.rst b/docs/source/sql/ml.rst index 5c3a3b9d1..7c388d1e7 100644 --- a/docs/source/sql/ml.rst +++ b/docs/source/sql/ml.rst @@ -48,9 +48,6 @@ The key-value parameters control, how and which model is trained: It is the full python module path to the class of the model to train. Any model class with sklearn interface is valid, but might or might not work well with Dask dataframes. - Have a look into the - `dask-ml documentation `_ - for more information on which models work best. You might need to install necessary packages to use the models. * ``target_column``: @@ -63,17 +60,13 @@ The key-value parameters control, how and which model is trained: * ``wrap_predict``: Boolean flag, whether to wrap the selected model with a :class:`dask_sql.physical.rel.custom.wrappers.ParallelPostFit`. - Have a look into the - `dask-ml docu on ParallelPostFit `_ - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if predicting on big data. + Defaults to false. Typically you set it to true for + sklearn models if predicting on big data. * ``wrap_fit``: Boolean flag, whether to wrap the selected - model with a :class:`dask_ml.wrappers.Incremental`. - Have a look into the - `dask-ml docu on Incremental `_ - to learn more about it. Defaults to false. Typically you set - it to true for sklearn models if training on big data. + model with a :class:`dask_sql.physical.rel.custom.wrappers.Incremental`. 
+ Defaults to false. Typically you set it to true for + sklearn models if training on big data. * ``fit_kwargs``: keyword arguments sent to the call to ``fit()``. @@ -85,7 +78,7 @@ Example: .. raw:: html
CREATE MODEL my_model WITH (
-        model_class = 'dask_ml.xgboost.XGBClassifier',
+        model_class = 'xgboost.XGBClassifier',
         target_column = 'target'
     ) AS (
         SELECT x, y, target
@@ -104,11 +97,10 @@ prediction, depending if your model can cope with
 dask dataframes.
 
     * if you are training on relatively small amounts
-      of data but predicting on large data samples
-      (and you are not using a model build for usage with dask
-      from the dask-ml package), you might want to set
-      ``wrap_predict`` to True. With this option,
-      model interference will be parallelized/distributed.
+      of data but predicting on large data samples,
+      you might want to set ``wrap_predict`` to True.
+      With this option, model inference will be
+      parallelized/distributed.
     * If you are training on large amounts of data,
       you can try setting wrap_fit to True. This will
       do the same on the training step, but works only on
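To ground the reworded `wrap_predict` guidance above, a minimal sketch (the toy table, columns, and estimator are illustrative; the CREATE MODEL options used are the ones documented on this page):

import dask.dataframe as dd
import pandas as pd

from dask_sql import Context

c = Context()
training_df = pd.DataFrame(
    {
        "x": [0.1, 0.5, 0.9, 0.3] * 25,
        "y": [1.0, 0.2, 0.7, 0.4] * 25,
        "target": [0, 1, 1, 0] * 25,
    }
)
c.create_table("timeseries", dd.from_pandas(training_df, npartitions=2))

# wrap_predict = True wraps the fitted sklearn estimator in the new
# ParallelPostFit class, so later predictions on dask data run partition-wise
# instead of being collected onto a single worker.
c.sql(
    """
    CREATE MODEL my_model WITH (
        model_class = 'sklearn.ensemble.GradientBoostingClassifier',
        wrap_predict = True,
        target_column = 'target'
    ) AS (
        SELECT x, y, target FROM timeseries
    )
    """
)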
diff --git a/tests/unit/test_ml_wrappers.py b/tests/unit/test_ml_wrappers.py
new file mode 100644
index 000000000..97277c1ad
--- /dev/null
+++ b/tests/unit/test_ml_wrappers.py
@@ -0,0 +1,250 @@
+# Copyright 2017, Dask developers
+# Dask-ML project - https://github.com/dask/dask-ml
+from collections.abc import Sequence
+
+import dask
+import dask.array as da
+import dask.dataframe as dd
+import numpy as np
+import pandas as pd
+import pytest
+from dask.array.utils import assert_eq as assert_eq_ar
+from dask.dataframe.utils import assert_eq as assert_eq_df
+from sklearn.base import clone
+from sklearn.decomposition import PCA
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.linear_model import LogisticRegression, SGDClassifier
+
+from dask_sql.physical.rel.custom.wrappers import Incremental, ParallelPostFit
+
+
+def _check_axis_partitioning(chunks, n_features):
+    c = chunks[1][0]
+    if c != n_features:
+        msg = (
+            "Can only generate arrays partitioned along the "
+            "first axis. Specifying a larger chunksize for "
+            "first axis. Specify a larger chunksize for "
+            "\tn_features: {}".format(c, n_features)
+        )
+        raise ValueError(msg)
+
+
+def check_random_state(random_state):
+    if random_state is None:
+        return da.random.RandomState()
+    # elif isinstance(random_state, Integral):
+    #     return da.random.RandomState(random_state)
+    elif isinstance(random_state, np.random.RandomState):
+        return da.random.RandomState(random_state.randint(2**32 - 1))
+    elif isinstance(random_state, da.random.RandomState):
+        return random_state
+    else:
+        raise TypeError("Unexpected type '{}'".format(type(random_state)))
+
+
+def make_classification(
+    n_samples=100,
+    n_features=20,
+    n_informative=2,
+    n_classes=2,
+    scale=1.0,
+    random_state=None,
+    chunks=None,
+):
+    chunks = da.core.normalize_chunks(chunks, (n_samples, n_features))
+    _check_axis_partitioning(chunks, n_features)
+
+    if n_classes != 2:
+        raise NotImplementedError("n_classes != 2 is not yet supported.")
+
+    rng = check_random_state(random_state)
+
+    X = rng.normal(0, 1, size=(n_samples, n_features), chunks=chunks)
+    informative_idx = rng.choice(n_features, n_informative, chunks=n_informative)
+    beta = (rng.random(n_features, chunks=n_features) - 1) * scale
+
+    informative_idx, beta = dask.compute(
+        informative_idx, beta, scheduler="single-threaded"
+    )
+
+    z0 = X[:, informative_idx].dot(beta[informative_idx])
+    y = rng.random(z0.shape, chunks=chunks[0]) < 1 / (1 + da.exp(-z0))
+    y = y.astype(int)
+
+    return X, y
+
+
+def _assert_eq(l, r, name=None, **kwargs):
+    array_types = (np.ndarray, da.Array)
+    frame_types = (pd.core.generic.NDFrame, dd._Frame)
+    if isinstance(l, array_types):
+        assert_eq_ar(l, r, **kwargs)
+    elif isinstance(l, frame_types):
+        assert_eq_df(l, r, **kwargs)
+    elif isinstance(l, Sequence) and any(
+        isinstance(x, array_types + frame_types) for x in l
+    ):
+        for a, b in zip(l, r):
+            _assert_eq(a, b, **kwargs)
+    elif np.isscalar(r) and np.isnan(r):
+        assert np.isnan(l), (name, l, r)
+    else:
+        assert l == r, (name, l, r)
+
+
+def assert_estimator_equal(left, right, exclude=None, **kwargs):
+    """Check that two Estimators are equal
+    Parameters
+    ----------
+    left, right : Estimators
+    exclude : str or sequence of str
+        attributes to skip in the check
+    kwargs : dict
+        Passed through to the dask `assert_eq` method.
+    """
+    left_attrs = [x for x in dir(left) if x.endswith("_") and not x.startswith("_")]
+    right_attrs = [x for x in dir(right) if x.endswith("_") and not x.startswith("_")]
+    if exclude is None:
+        exclude = set()
+    elif isinstance(exclude, str):
+        exclude = {exclude}
+    else:
+        exclude = set(exclude)
+
+    left_attrs2 = set(left_attrs) - exclude
+    right_attrs2 = set(right_attrs) - exclude
+
+    assert left_attrs2 == right_attrs2, left_attrs2 ^ right_attrs2
+
+    for attr in left_attrs2:
+        l = getattr(left, attr)
+        r = getattr(right, attr)
+        _assert_eq(l, r, name=attr, **kwargs)
+
+
+def test_parallelpostfit_basic():
+    clf = ParallelPostFit(GradientBoostingClassifier())
+
+    X, y = make_classification(n_samples=1000, chunks=100)
+    X_, y_ = dask.compute(X, y)
+    clf.fit(X_, y_)
+
+    assert isinstance(clf.predict(X), da.Array)
+    assert isinstance(clf.predict_proba(X), da.Array)
+
+    result = clf.score(X, y)
+    expected = clf.estimator.score(X_, y_)
+    assert result == expected
+
+
+@pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"])
+def test_predict(kind):
+    X, y = make_classification(chunks=100)
+
+    if kind == "numpy":
+        X, y = dask.compute(X, y)
+    elif kind == "dask.dataframe":
+        X = dd.from_dask_array(X)
+        y = dd.from_dask_array(y)
+
+    base = LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs")
+    wrap = ParallelPostFit(
+        LogisticRegression(random_state=0, n_jobs=1, solver="lbfgs"),
+    )
+
+    base.fit(*dask.compute(X, y))
+    wrap.fit(*dask.compute(X, y))
+
+    assert_estimator_equal(wrap.estimator, base)
+
+    result = wrap.predict(X)
+    expected = base.predict(X)
+    assert_eq_ar(result, expected)
+
+    result = wrap.predict_proba(X)
+    expected = base.predict_proba(X)
+    assert_eq_ar(result, expected)
+
+    result = wrap.predict_log_proba(X)
+    expected = base.predict_log_proba(X)
+    assert_eq_ar(result, expected)
+
+
+@pytest.mark.parametrize("kind", ["numpy", "dask.dataframe", "dask.array"])
+def test_transform(kind):
+    X, y = make_classification(chunks=100)
+
+    if kind == "numpy":
+        X, y = dask.compute(X, y)
+    elif kind == "dask.dataframe":
+        X = dd.from_dask_array(X)
+        y = dd.from_dask_array(y)
+
+    base = PCA(random_state=0)
+    wrap = ParallelPostFit(PCA(random_state=0))
+
+    base.fit(*dask.compute(X, y))
+    wrap.fit(*dask.compute(X, y))
+
+    assert_estimator_equal(wrap.estimator, base)
+
+    result = base.transform(*dask.compute(X))
+    expected = wrap.transform(X)
+    assert_eq_ar(result, expected)
+
+
+@pytest.mark.parametrize("dataframes", [False, True])
+def test_incremental_basic(dataframes):
+    # Create observations that we know linear models can recover
+    n, d = 100, 3
+    rng = da.random.RandomState(42)
+    X = rng.normal(size=(n, d), chunks=30)
+    coef_star = rng.uniform(size=d, chunks=d)
+    y = da.sign(X.dot(coef_star))
+    y = (y + 1) / 2
+    if dataframes:
+        X = dd.from_array(X)
+        y = dd.from_array(y)
+
+    est1 = SGDClassifier(random_state=0, tol=1e-3, average=True)
+    est2 = clone(est1)
+
+    clf = Incremental(est1, random_state=0)
+    result = clf.fit(X, y, classes=[0, 1])
+    assert result is clf
+
+    # est2 is a sklearn optimizer; this is just a benchmark
+    if dataframes:
+        X = X.to_dask_array(lengths=True)
+        y = y.to_dask_array(lengths=True)
+
+    for slice_ in da.core.slices_from_chunks(X.chunks):
+        est2.partial_fit(X[slice_].compute(), y[slice_[0]].compute(), classes=[0, 1])
+
+    assert isinstance(result.estimator_.coef_, np.ndarray)
+    rel_error = np.linalg.norm(clf.coef_ - est2.coef_)
+    rel_error /= np.linalg.norm(clf.coef_)
+    assert rel_error < 0.9
+
+    assert set(dir(clf.estimator_)) == set(dir(est2))
+
+    #  Predict
+    result = clf.predict(X)
+    expected = est2.predict(X)
+    assert isinstance(result, da.Array)
+    if dataframes:
+        # Compute is needed because chunk sizes of this array are unknown
+        result = result.compute()
+    rel_error = np.linalg.norm(result - expected)
+    rel_error /= np.linalg.norm(expected)
+    assert rel_error < 0.3
+
+    # score
+    result = clf.score(X, y)
+    expected = est2.score(*dask.compute(X, y))
+    assert abs(result - expected) < 0.1
+
+    clf = Incremental(SGDClassifier(random_state=0, tol=1e-3, average=True))
+    clf.partial_fit(X, y, classes=[0, 1])
+    assert set(dir(clf.estimator_)) == set(dir(est2))
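Beyond the SQL layer, the new wrappers can also be driven directly; a minimal sketch modeled on the tests above (the estimator choice, array shapes, and chunk sizes are illustrative only):

import dask.array as da
from sklearn.linear_model import SGDClassifier

from dask_sql.physical.rel.custom.wrappers import Incremental

# Toy blocked data: 10 blocks of 100 rows each.
X = da.random.random((1000, 3), chunks=(100, 3))
y = (da.random.random(1000, chunks=100) > 0.5).astype(int)

# Incremental feeds each block to SGDClassifier.partial_fit in turn, so the
# estimator never has to hold the full dataset in memory.
clf = Incremental(SGDClassifier(tol=1e-3), scoring="accuracy")
clf.fit(X, y, classes=[0, 1])

# Post-fit methods come from ParallelPostFit and stay lazy and block-wise.
predictions = clf.predict(X)   # a dask array
accuracy = clf.score(X, y)     # routed through the accuracy scorer in metrics.py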

From 2697742a158c73f41196de7dcb850b31fec3567d Mon Sep 17 00:00:00 2001
From: Ayush Dattagupta 
Date: Mon, 14 Nov 2022 16:12:22 -0800
Subject: [PATCH 14/19] Update flake8 link to use github (#915)

* Update flake8 link to use github

* Update black, isort and flake8 versions for style checks
---
 .pre-commit-config.yaml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0bda454e2..6b5c93ec5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,16 +1,16 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 22.3.0
+    rev: 22.10.0
     hooks:
       - id: black
         language_version: python3
-  - repo: https://gitlab.com/pycqa/flake8
-    rev: 3.9.2
+  - repo: https://github.com/PyCQA/flake8
+    rev: 5.0.4
     hooks:
       - id: flake8
         language_version: python3
   - repo: https://github.com/pycqa/isort
-    rev: 5.7.0
+    rev: 5.10.1
     hooks:
       - id: isort
         args:

From 0edf6e1851fa8edafa2b341bf71bc4eb53550038 Mon Sep 17 00:00:00 2001
From: jakirkham 
Date: Tue, 15 Nov 2022 06:46:42 -0800
Subject: [PATCH 15/19] Use `conda-incubator/setup-miniconda@v2.2.0` & enable
 automatic GH Action updates (#917)

* Use `conda-incubator/setup-miniconda@v2.2.0`

* Handle GH Action updates with Dependabot
---
 .github/dependabot.yml              | 5 +++++
 .github/workflows/conda.yml         | 2 +-
 .github/workflows/release.yml       | 2 +-
 .github/workflows/test-upstream.yml | 6 +++---
 .github/workflows/test.yml          | 6 +++---
 5 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index 5fe73017d..4d0e1d08e 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -10,3 +10,8 @@ updates:
         update-types: ["version-update:semver-major"]
       - dependency-name: "datafusion-*"
         update-types: ["version-update:semver-major"]
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      # Check for updates to GitHub Actions every week
+      interval: "weekly"
diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index 78047147a..566895bdb 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -36,7 +36,7 @@ jobs:
         with:
           fetch-depth: 0
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b4e3ac7ed..f6648d04e 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -50,7 +50,7 @@ jobs:
           output-dir: dist
           config-file: "dask_planner/pyproject.toml"
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
diff --git a/.github/workflows/test-upstream.yml b/.github/workflows/test-upstream.yml
index 1e3e5caa9..79678cf59 100644
--- a/.github/workflows/test-upstream.yml
+++ b/.github/workflows/test-upstream.yml
@@ -50,7 +50,7 @@ jobs:
         with:
           fetch-depth: 0 # Fetch all history for all branches and tags.
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
@@ -88,7 +88,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
@@ -140,7 +140,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           python-version: "3.8"
           mamba-version: "*"
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index a20ee2b14..8b330acd5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -46,7 +46,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
@@ -89,7 +89,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           miniforge-variant: Mambaforge
           use-mamba: true
@@ -139,7 +139,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v2.2.0
         with:
           python-version: "3.8"
           mamba-version: "*"

From c7017a77ae9102cbfcd83170d2f1c83181ae8a19 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Tue, 15 Nov 2022 09:54:01 -0500
Subject: [PATCH 16/19] Bump uuid from 1.2.1 to 1.2.2 in /dask_planner (#916)

Bumps [uuid](https://github.com/uuid-rs/uuid) from 1.2.1 to 1.2.2.
- [Release notes](https://github.com/uuid-rs/uuid/releases)
- [Commits](https://github.com/uuid-rs/uuid/compare/1.2.1...1.2.2)

---
updated-dependencies:
- dependency-name: uuid
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] 

Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 dask_planner/Cargo.lock | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock
index b6586e238..44022646e 100644
--- a/dask_planner/Cargo.lock
+++ b/dask_planner/Cargo.lock
@@ -1232,9 +1232,9 @@ checksum = "58ee9362deb4a96cef4d437d1ad49cffc9b9e92d202b6995674e928ce684f112"
 
 [[package]]
 name = "uuid"
-version = "1.2.1"
+version = "1.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83"
+checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c"
 dependencies = [
  "getrandom",
 ]

From ab246b004a7147a63c567c455116f498c2a6d15a Mon Sep 17 00:00:00 2001
From: Andy Grove 
Date: Tue, 15 Nov 2022 11:39:11 -0700
Subject: [PATCH 17/19] Upgrade to DataFusion 14.0.0 (#903)

* upgrade to latest datafusion rev

* panic on unexpected value

* remove panic

* fix regression with window functions

* fix regression

* use official release of DataFusion

* update optimizer rules list

* add filter_push_down rule from DataFusion 13

* fix

* add expr simplifier rule but without optimization for rewriting small 'in' expressions

* remove unused imports

* Disable EliminateFilter optimization to unblock regressions

* Use upstream SimplifyExpressions, catch associated KeyError

* Forbid auto-index setting in attempt_predicate_pushdown

* Ignore index in test_predicate_pushdown

* Add dask version check to predicate pushdown tests

* Add TODO for index specification

Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
---
 dask_planner/Cargo.lock                       | 313 ++++++---
 dask_planner/Cargo.toml                       |  10 +-
 dask_planner/src/sql.rs                       |   2 +-
 dask_planner/src/sql/logical.rs               |   1 +
 dask_planner/src/sql/logical/join.rs          |   6 +-
 dask_planner/src/sql/logical/window.rs        |  72 +-
 dask_planner/src/sql/optimizer.rs             | 109 ++-
 .../sql/optimizer/eliminate_agg_distinct.rs   |   2 +-
 .../src/sql/optimizer/filter_push_down.rs     | 641 ++++++++++++++++++
 dask_planner/src/sql/types.rs                 |   2 +-
 dask_sql/physical/rel/logical/join.py         |   2 +-
 dask_sql/physical/utils/filter.py             |   3 +-
 tests/integration/test_filter.py              |  10 +-
 13 files changed, 980 insertions(+), 193 deletions(-)
 create mode 100644 dask_planner/src/sql/optimizer/filter_push_down.rs

diff --git a/dask_planner/Cargo.lock b/dask_planner/Cargo.lock
index 44022646e..ecce581d2 100644
--- a/dask_planner/Cargo.lock
+++ b/dask_planner/Cargo.lock
@@ -15,9 +15,9 @@ dependencies = [
 
 [[package]]
 name = "ahash"
-version = "0.8.0"
+version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "57e6e951cfbb2db8de1828d49073a113a29fd7117b1596caa781a258c7e38d72"
+checksum = "464b3811b747f8f7ebc8849c9c728c39f6ac98a055edad93baf9eb330e3f8f9d"
 dependencies = [
  "cfg-if",
  "const-random",
@@ -58,15 +58,16 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
 
 [[package]]
 name = "arrow"
-version = "25.0.0"
+version = "26.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76312eb67808c67341f4234861c4fcd2f9868f55e88fa2186ab3b357a6c5830b"
+checksum = "e24e2bcd431a4aa0ff003fdd2dc21c78cfb42f31459c89d2312c2746fe17a5ac"
 dependencies = [
- "ahash 0.8.0",
+ "ahash 0.8.1",
  "arrow-array",
  "arrow-buffer",
  "arrow-data",
  "arrow-schema",
+ "arrow-select",
  "bitflags",
  "chrono",
  "comfy-table",
@@ -86,11 +87,11 @@ dependencies = [
 
 [[package]]
 name = "arrow-array"
-version = "25.0.0"
+version = "26.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "69dd2c257fa76de0bcc63cabe8c81d34c46ef6fa7651e3e497922c3c9878bd67"
+checksum = "c9044300874385f19e77cbf90911e239bd23630d8f23bb0f948f9067998a13b7"
 dependencies = [
- "ahash 0.8.0",
+ "ahash 0.8.1",
  "arrow-buffer",
  "arrow-data",
  "arrow-schema",
@@ -102,9 +103,9 @@ dependencies = [
 
 [[package]]
 name = "arrow-buffer"
-version = "25.0.0"
+version = "26.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af963e71bdbbf928231d521083ddc8e8068cf5c8d45d4edcfeaf7eb5cdd779a9"
+checksum = "78476cbe9e3f808dcecab86afe42d573863c63e149c62e6e379ed2522743e626"
 dependencies = [
  "half",
  "num",
@@ -112,9 +113,9 @@ dependencies = [
 
 [[package]]
 name = "arrow-data"
-version = "25.0.0"
+version = "26.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52554ffff560c366d7210c2621a3cf1dc408f9969a0c7688a3ba0a62248a945d"
+checksum = "4d916feee158c485dad4f701cba31bc9a90a8db87d9df8e2aa8adc0c20a2bbb9"
 dependencies = [
  "arrow-buffer",
  "arrow-schema",
@@ -124,9 +125,22 @@ dependencies = [
 
 [[package]]
 name = "arrow-schema"
-version = "25.0.0"
+version = "26.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1a5518f2bd7775057391f88257627cbb760ba3e1c2f2444a005ba79158624654"
+checksum = "0f9406eb7834ca6bd8350d1baa515d18b9fcec487eddacfb62f5e19511f7bd37"
+
+[[package]]
+name = "arrow-select"
+version = "26.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6593a01586751c74498495d2f5a01fcd438102b52965c11dd98abf4ebcacef37"
+dependencies = [
+ "arrow-array",
+ "arrow-buffer",
+ "arrow-data",
+ "arrow-schema",
+ "num",
+]
 
 [[package]]
 name = "async-trait"
@@ -208,15 +222,15 @@ dependencies = [
 
 [[package]]
 name = "bumpalo"
-version = "3.11.0"
+version = "3.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d"
+checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba"
 
 [[package]]
 name = "cc"
-version = "1.0.73"
+version = "1.0.74"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11"
+checksum = "581f5dba903aac52ea3feb5ec4810848460ee833876f1f9b0fdeab1f19091574"
 
 [[package]]
 name = "cfg-if"
@@ -236,11 +250,21 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "codespan-reporting"
+version = "0.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3538270d33cc669650c4b093848450d380def10c331d38c768e34cac80576e6e"
+dependencies = [
+ "termcolor",
+ "unicode-width",
+]
+
 [[package]]
 name = "comfy-table"
-version = "6.1.0"
+version = "6.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "85914173c2f558d61613bfbbf1911f14e630895087a7ed2fafc0f5319e1536e7"
+checksum = "1090f39f45786ec6dc6286f8ea9c75d0a7ef0a0d3cda674cef0c3af7b307fbc2"
 dependencies = [
  "strum",
  "strum_macros",
@@ -249,9 +273,9 @@ dependencies = [
 
 [[package]]
 name = "const-random"
-version = "0.1.13"
+version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f590d95d011aa80b063ffe3253422ed5aa462af4e9867d43ce8337562bac77c4"
+checksum = "368a7a772ead6ce7e1de82bfb04c485f3db8ec744f72925af5735e29a22cc18e"
 dependencies = [
  "const-random-macro",
  "proc-macro-hack",
@@ -259,12 +283,12 @@ dependencies = [
 
 [[package]]
 name = "const-random-macro"
-version = "0.1.13"
+version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "615f6e27d000a2bffbc7f2f6a8669179378fa27ee4d0a509e985dfc0a7defb40"
+checksum = "9d7d6ab3c3a2282db210df5f02c4dab6e0a7057af0fb7ebd4070f30fe05c0ddb"
 dependencies = [
  "getrandom",
- "lazy_static",
+ "once_cell",
  "proc-macro-hack",
  "tiny-keccak",
 ]
@@ -328,6 +352,50 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "cxx"
+version = "1.0.80"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b7d4e43b25d3c994662706a1d4fcfc32aaa6afd287502c111b237093bb23f3a"
+dependencies = [
+ "cc",
+ "cxxbridge-flags",
+ "cxxbridge-macro",
+ "link-cplusplus",
+]
+
+[[package]]
+name = "cxx-build"
+version = "1.0.80"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "84f8829ddc213e2c1368e51a2564c552b65a8cb6a28f31e576270ac81d5e5827"
+dependencies = [
+ "cc",
+ "codespan-reporting",
+ "once_cell",
+ "proc-macro2",
+ "quote",
+ "scratch",
+ "syn",
+]
+
+[[package]]
+name = "cxxbridge-flags"
+version = "1.0.80"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e72537424b474af1460806647c41d4b6d35d09ef7fe031c5c2fa5766047cc56a"
+
+[[package]]
+name = "cxxbridge-macro"
+version = "1.0.80"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "309e4fb93eed90e1e14bea0da16b209f81813ba9fc7830c20ed151dd7bc0a4d7"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "dask_planner"
 version = "0.1.0"
@@ -350,20 +418,23 @@ dependencies = [
 
 [[package]]
 name = "datafusion-common"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15f1ffcbc1f040c9ab99f41db1c743d95aff267bb2e7286aaa010738b7402251"
 dependencies = [
  "arrow",
+ "chrono",
  "ordered-float",
  "sqlparser",
 ]
 
 [[package]]
 name = "datafusion-expr"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1883d9590d303ef38fa295567e7fdb9f8f5f511fcc167412d232844678cd295c"
 dependencies = [
- "ahash 0.8.0",
+ "ahash 0.8.1",
  "arrow",
  "datafusion-common",
  "log",
@@ -372,8 +443,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-optimizer"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2127d46d566ab3463d70da9675fc07b9d634be8d17e80d0e1ce79600709fe651"
 dependencies = [
  "arrow",
  "async-trait",
@@ -387,32 +459,40 @@ dependencies = [
 
 [[package]]
 name = "datafusion-physical-expr"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0d108b6fe8eeb317ecad1d74619e8758de49cccc8c771b56c97962fd52eaae23"
 dependencies = [
- "ahash 0.8.0",
+ "ahash 0.8.1",
  "arrow",
+ "arrow-buffer",
+ "arrow-schema",
  "blake2",
  "blake3",
  "chrono",
  "datafusion-common",
  "datafusion-expr",
  "datafusion-row",
+ "half",
  "hashbrown",
+ "itertools",
  "lazy_static",
  "md-5",
+ "num-traits",
  "ordered-float",
  "paste",
  "rand",
  "regex",
  "sha2",
  "unicode-segmentation",
+ "uuid",
 ]
 
 [[package]]
 name = "datafusion-row"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43537b6377d506e4788bf21e9ed943340e076b48ca4d077e6ea4405ca5e54a1c"
 dependencies = [
  "arrow",
  "datafusion-common",
@@ -422,8 +502,9 @@ dependencies = [
 
 [[package]]
 name = "datafusion-sql"
-version = "13.0.0"
-source = "git+https://github.com/apache/arrow-datafusion/?rev=54d2870a56d8d8f914a617a7e2d52e387ef5dba2#54d2870a56d8d8f914a617a7e2d52e387ef5dba2"
+version = "14.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "244d08d4710e1088d9c0949c9b5b8d68d9cf2cde7203134a4cc389e870fe2354"
 dependencies = [
  "arrow",
  "datafusion-common",
@@ -442,6 +523,12 @@ dependencies = [
  "subtle",
 ]
 
+[[package]]
+name = "either"
+version = "1.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "90e5c1c8368803113bf0c9584fc495a58b86dc8a29edbf8fe877d21d9507e797"
+
 [[package]]
 name = "env_logger"
 version = "0.9.3"
@@ -457,12 +544,11 @@ dependencies = [
 
 [[package]]
 name = "flatbuffers"
-version = "2.1.2"
+version = "22.9.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "86b428b715fdbdd1c364b84573b5fdc0f84f8e423661b9f398735278bc7f2b6a"
+checksum = "8ce016b9901aef3579617931fbb2df8fc9a9f7cb95a16eb8acc8148209bb9e70"
 dependencies = [
  "bitflags",
- "smallvec",
  "thiserror",
 ]
 
@@ -478,13 +564,15 @@ dependencies = [
 
 [[package]]
 name = "getrandom"
-version = "0.2.7"
+version = "0.2.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
+checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31"
 dependencies = [
  "cfg-if",
+ "js-sys",
  "libc",
  "wasi",
+ "wasm-bindgen",
 ]
 
 [[package]]
@@ -529,17 +617,28 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
 
 [[package]]
 name = "iana-time-zone"
-version = "0.1.50"
+version = "0.1.53"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "fd911b35d940d2bd0bea0f9100068e5b97b51a1cbe13d13382f132e0365257a0"
+checksum = "64c122667b287044802d6ce17ee2ddf13207ed924c712de9a66a5814d5b64765"
 dependencies = [
  "android_system_properties",
  "core-foundation-sys",
+ "iana-time-zone-haiku",
  "js-sys",
  "wasm-bindgen",
  "winapi",
 ]
 
+[[package]]
+name = "iana-time-zone-haiku"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0703ae284fc167426161c2e3f1da3ea71d94b21bedbcc9494e92b28e334e3dca"
+dependencies = [
+ "cxx",
+ "cxx-build",
+]
+
 [[package]]
 name = "indexmap"
 version = "1.9.1"
@@ -556,6 +655,15 @@ version = "1.0.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3"
 
+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itoa"
 version = "0.4.8"
@@ -649,9 +757,9 @@ dependencies = [
 
 [[package]]
 name = "libc"
-version = "0.2.134"
+version = "0.2.137"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "329c933548736bc49fd575ee68c89e8be4d260064184389a5b77517cddd99ffb"
+checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89"
 
 [[package]]
 name = "libm"
@@ -668,6 +776,15 @@ dependencies = [
  "cc",
 ]
 
+[[package]]
+name = "link-cplusplus"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9272ab7b96c9046fbc5bc56c06c117cb639fe2d509df0c421cad82d2915cf369"
+dependencies = [
+ "cc",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.9"
@@ -819,9 +936,9 @@ dependencies = [
 
 [[package]]
 name = "num_cpus"
-version = "1.13.1"
+version = "1.14.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1"
+checksum = "f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5"
 dependencies = [
  "hermit-abi",
  "libc",
@@ -829,15 +946,15 @@ dependencies = [
 
 [[package]]
 name = "once_cell"
-version = "1.15.0"
+version = "1.16.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1"
+checksum = "86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860"
 
 [[package]]
 name = "ordered-float"
-version = "3.2.0"
+version = "3.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "129d36517b53c461acc6e1580aeb919c8ae6708a4b1eae61c4463a615d4f0411"
+checksum = "d84eb1409416d254e4a9c8fa56cc24701755025b458f0fcd8e59e1f5f40c23bf"
 dependencies = [
  "num-traits",
 ]
@@ -854,9 +971,9 @@ dependencies = [
 
 [[package]]
 name = "parking_lot_core"
-version = "0.9.3"
+version = "0.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09a279cbf25cb0757810394fbc1e359949b59e348145c643a939a525692e6929"
+checksum = "4dc9e0dc2adc1c69d09143aff38d3d30c5c3f0df0dad82e6d25547af174ebec0"
 dependencies = [
  "cfg-if",
  "libc",
@@ -879,9 +996,9 @@ checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
 
 [[package]]
 name = "ppv-lite86"
-version = "0.2.16"
+version = "0.2.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
+checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
 
 [[package]]
 name = "proc-macro-hack"
@@ -891,9 +1008,9 @@ checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.46"
+version = "1.0.47"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "94e2ef8dbfc347b10c094890f778ee2e36ca9bb4262e86dc99cd217e35f3470b"
+checksum = "5ea3d908b0e36316caf9e9e2c4625cdde190a7e6f440d794667ed17a1855e725"
 dependencies = [
  "unicode-ident",
 ]
@@ -1008,9 +1125,9 @@ dependencies = [
 
 [[package]]
 name = "regex"
-version = "1.6.0"
+version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b"
+checksum = "e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a"
 dependencies = [
  "aho-corasick",
  "memchr",
@@ -1025,9 +1142,9 @@ checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
 
 [[package]]
 name = "regex-syntax"
-version = "0.6.27"
+version = "0.6.28"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244"
+checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
 
 [[package]]
 name = "rustversion"
@@ -1047,17 +1164,23 @@ version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
 
+[[package]]
+name = "scratch"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c8132065adcfd6e02db789d9285a0deb2f3fcb04002865ab67d5fb103533898"
+
 [[package]]
 name = "serde"
-version = "1.0.145"
+version = "1.0.147"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "728eb6351430bccb993660dfffc5a72f91ccc1295abaa8ce19b27ebe4f75568b"
+checksum = "d193d69bae983fc11a79df82342761dfbf28a99fc8d203dca4c3c1b590948965"
 
 [[package]]
 name = "serde_json"
-version = "1.0.85"
+version = "1.0.87"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e55a28e3aaef9d5ce0506d0a14dbba8054ddc7e499ef522dd8b26859ec9d4a44"
+checksum = "6ce777b7b150d76b9cf60d28b55f5847135a003f7d7350c6be7a773508ce7d45"
 dependencies = [
  "itoa 1.0.4",
  "ryu",
@@ -1083,9 +1206,9 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
 
 [[package]]
 name = "sqlparser"
-version = "0.25.0"
+version = "0.26.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0781f2b6bd03e5adf065c8e772b49eaea9f640d06a1b9130330fe8bd2563f4fd"
+checksum = "86be66ea0b2b22749cfa157d16e2e84bf793e626a3375f4d378dc289fa03affb"
 dependencies = [
  "log",
 ]
@@ -1123,9 +1246,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
 
 [[package]]
 name = "syn"
-version = "1.0.102"
+version = "1.0.103"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3fcd952facd492f9be3ef0d0b7032a6e442ee9b361d4acc2b1d0c4aaa5f613a1"
+checksum = "a864042229133ada95abf3b54fdc62ef5ccabe9515b64717bcb9a1919e59445d"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1134,9 +1257,9 @@ dependencies = [
 
 [[package]]
 name = "target-lexicon"
-version = "0.12.4"
+version = "0.12.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c02424087780c9b71cc96799eaeddff35af2bc513278cda5c99fc1f5d026d3c1"
+checksum = "9410d0f6853b1d94f0e519fb95df60f29d2c1eff2d921ffdf01a4c8a3b54f12d"
 
 [[package]]
 name = "termcolor"
@@ -1208,9 +1331,9 @@ checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
 
 [[package]]
 name = "unicode-ident"
-version = "1.0.4"
+version = "1.0.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dcc811dc4066ac62f84f11307873c4850cb653bfa9b1719cee2bd2204a4bc5dd"
+checksum = "6ceab39d59e4c9499d4e5a8ee0e2735b891bb7308ac83dfb4e80cad195c9f6f3"
 
 [[package]]
 name = "unicode-segmentation"
@@ -1338,43 +1461,57 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
 
 [[package]]
 name = "windows-sys"
-version = "0.36.1"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
+checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
 dependencies = [
+ "windows_aarch64_gnullvm",
  "windows_aarch64_msvc",
  "windows_i686_gnu",
  "windows_i686_msvc",
  "windows_x86_64_gnu",
+ "windows_x86_64_gnullvm",
  "windows_x86_64_msvc",
 ]
 
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41d2aa71f6f0cbe00ae5167d90ef3cfe66527d6f613ca78ac8024c3ccab9a19e"
+
 [[package]]
 name = "windows_aarch64_msvc"
-version = "0.36.1"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
+checksum = "dd0f252f5a35cac83d6311b2e795981f5ee6e67eb1f9a7f64eb4500fbc4dcdb4"
 
 [[package]]
 name = "windows_i686_gnu"
-version = "0.36.1"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
+checksum = "fbeae19f6716841636c28d695375df17562ca208b2b7d0dc47635a50ae6c5de7"
 
 [[package]]
 name = "windows_i686_msvc"
-version = "0.36.1"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
+checksum = "84c12f65daa39dd2babe6e442988fc329d6243fdce47d7d2d155b8d874862246"
 
 [[package]]
 name = "windows_x86_64_gnu"
-version = "0.36.1"
+version = "0.42.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bf7b1b21b5362cbc318f686150e5bcea75ecedc74dd157d874d754a2ca44b0ed"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
+checksum = "09d525d2ba30eeb3297665bd434a54297e4170c7f1a44cad4ef58095b4cd2028"
 
 [[package]]
 name = "windows_x86_64_msvc"
-version = "0.36.1"
+version = "0.42.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
+checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
diff --git a/dask_planner/Cargo.toml b/dask_planner/Cargo.toml
index 04079c92b..9ac4e8a44 100644
--- a/dask_planner/Cargo.toml
+++ b/dask_planner/Cargo.toml
@@ -9,12 +9,12 @@ edition = "2021"
 rust-version = "1.62"
 
 [dependencies]
-arrow = { version = "25.0.0", features = ["prettyprint"] }
+arrow = { version = "26.0.0", features = ["prettyprint"] }
 async-trait = "0.1.58"
-datafusion-common = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" }
-datafusion-expr = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" }
-datafusion-optimizer = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" }
-datafusion-sql = { git = "https://github.com/apache/arrow-datafusion/", rev = "54d2870a56d8d8f914a617a7e2d52e387ef5dba2" }
+datafusion-common = "14.0.0"
+datafusion-expr = "14.0.0"
+datafusion-optimizer = "14.0.0"
+datafusion-sql = "14.0.0"
 env_logger = "0.9"
 log = "^0.4"
 mimalloc = { version = "*", default-features = false }
diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index 5f5baaa34..d52211ff7 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -400,7 +400,7 @@ impl DaskSQLContext {
             Ok(valid) => {
                 if valid {
                     optimizer::DaskSqlOptimizer::new(true)
-                        .run_optimizations(existing_plan.original_plan)
+                        .optimize(existing_plan.original_plan)
                         .map(|k| PyLogicalPlan {
                             original_plan: k,
                             current_node: None,
diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs
index 17bea5343..565bd097d 100644
--- a/dask_planner/src/sql/logical.rs
+++ b/dask_planner/src/sql/logical.rs
@@ -299,6 +299,7 @@ impl PyLogicalPlan {
             LogicalPlan::CreateCatalogSchema(_create) => "CreateCatalogSchema",
             LogicalPlan::CreateCatalog(_create_catalog) => "CreateCatalog",
             LogicalPlan::CreateView(_create_view) => "CreateView",
+            LogicalPlan::SetVariable(_) => "SetVariable",
             // Further examine and return the name that is a possible Dask-SQL Extension type
             LogicalPlan::Extension(extension) => {
                 let node = extension.node.as_any();
diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs
index 3ddcb757e..e9a0a6485 100644
--- a/dask_planner/src/sql/logical/join.rs
+++ b/dask_planner/src/sql/logical/join.rs
@@ -82,8 +82,10 @@ impl PyJoin {
             JoinType::Left => Ok("LEFT".to_string()),
             JoinType::Right => Ok("RIGHT".to_string()),
             JoinType::Full => Ok("FULL".to_string()),
-            JoinType::Semi => Ok("SEMI".to_string()),
-            JoinType::Anti => Ok("ANTI".to_string()),
+            JoinType::LeftSemi => Ok("LEFTSEMI".to_string()),
+            JoinType::LeftAnti => Ok("LEFTANTI".to_string()),
+            JoinType::RightSemi => Ok("RIGHTSEMI".to_string()),
+            JoinType::RightAnti => Ok("RIGHTANTI".to_string()),
         }
     }
 }
diff --git a/dask_planner/src/sql/logical/window.rs b/dask_planner/src/sql/logical/window.rs
index 91d7485ee..fd0b04196 100644
--- a/dask_planner/src/sql/logical/window.rs
+++ b/dask_planner/src/sql/logical/window.rs
@@ -1,7 +1,9 @@
+use datafusion_common::ScalarValue;
 use datafusion_expr::{logical_plan::Window, Expr, LogicalPlan, WindowFrame, WindowFrameBound};
 use pyo3::prelude::*;
 
 use crate::{
+    error::DaskPlannerError,
     expression::{py_expr_list, PyExpr},
     sql::exceptions::py_type_err,
 };
@@ -58,57 +60,45 @@ impl PyWindow {
     /// Returns order by columns in a window function expression
     #[pyo3(name = "getSortExprs")]
     pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
-        match expr.expr {
+        match expr.expr.unalias() {
             Expr::WindowFunction { order_by, .. } => py_expr_list(&self.window.input, &order_by),
-            _ => Err(py_type_err(format!(
-                "Provided Expr {:?} is not a WindowFunction type",
-                expr
-            ))),
+            other => Err(not_window_function_err(other)),
         }
     }
 
     /// Return partition by columns in a window function expression
     #[pyo3(name = "getPartitionExprs")]
     pub fn get_partition_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
-        match expr.expr {
+        match expr.expr.unalias() {
             Expr::WindowFunction { partition_by, .. } => {
                 py_expr_list(&self.window.input, &partition_by)
             }
-            _ => Err(py_type_err(format!(
-                "Provided Expr {:?} is not a WindowFunction type",
-                expr
-            ))),
+            other => Err(not_window_function_err(other)),
         }
     }
 
     /// Return input args for window function
     #[pyo3(name = "getArgs")]
     pub fn get_args(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
-        match expr.expr {
+        match expr.expr.unalias() {
             Expr::WindowFunction { args, .. } => py_expr_list(&self.window.input, &args),
-            _ => Err(py_type_err(format!(
-                "Provided Expr {:?} is not a WindowFunction type",
-                expr
-            ))),
+            other => Err(not_window_function_err(other)),
         }
     }
 
     /// Return window function name
     #[pyo3(name = "getWindowFuncName")]
     pub fn window_func_name(&self, expr: PyExpr) -> PyResult<String> {
-        match expr.expr {
+        match expr.expr.unalias() {
             Expr::WindowFunction { fun, .. } => Ok(fun.to_string()),
-            _ => Err(py_type_err(format!(
-                "Provided Expr {:?} is not a WindowFunction type",
-                expr
-            ))),
+            other => Err(not_window_function_err(other)),
         }
     }
 
     /// Returns a Pywindow frame for a given window function expression
     #[pyo3(name = "getWindowFrame")]
     pub fn get_window_frame(&self, expr: PyExpr) -> Option<PyWindowFrame> {
-        match expr.expr {
+        match expr.expr.unalias() {
             Expr::WindowFunction { window_frame, .. } => {
                 window_frame.map(|window_frame| window_frame.into())
             }
@@ -117,6 +107,14 @@ impl PyWindow {
     }
 }
 
+fn not_window_function_err(expr: Expr) -> PyErr {
+    py_type_err(format!(
+        "Provided {} Expr {:?} is not a WindowFunction type",
+        expr.variant_name(),
+        expr
+    ))
+}
+
 #[pymethods]
 impl PyWindowFrame {
     /// Returns the window frame units for the bounds
@@ -127,12 +125,12 @@ impl PyWindowFrame {
     /// Returns starting bound
     #[pyo3(name = "getLowerBound")]
     pub fn get_lower_bound(&self) -> PyResult<PyWindowFrameBound> {
-        Ok(self.window_frame.start_bound.into())
+        Ok(self.window_frame.start_bound.clone().into())
     }
     /// Returns end bound
     #[pyo3(name = "getUpperBound")]
     pub fn get_upper_bound(&self) -> PyResult<PyWindowFrameBound> {
-        Ok(self.window_frame.end_bound.into())
+        Ok(self.window_frame.end_bound.clone().into())
     }
 }
 
@@ -147,28 +145,38 @@ impl PyWindowFrameBound {
     /// Returns if the frame bound is preceding
     #[pyo3(name = "isPreceding")]
     pub fn is_preceding(&self) -> bool {
-        matches!(self.frame_bound, WindowFrameBound::Preceding(..))
+        matches!(self.frame_bound, WindowFrameBound::Preceding(_))
     }
 
     /// Returns if the frame bound is following
     #[pyo3(name = "isFollowing")]
     pub fn is_following(&self) -> bool {
-        matches!(self.frame_bound, WindowFrameBound::Following(..))
+        matches!(self.frame_bound, WindowFrameBound::Following(_))
     }
     /// Returns the offset of the window frame
     #[pyo3(name = "getOffset")]
-    pub fn get_offset(&self) -> Option<u64> {
+    pub fn get_offset(&self) -> PyResult<Option<u64>> {
         match self.frame_bound {
-            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => val,
-            WindowFrameBound::CurrentRow => None,
+            WindowFrameBound::Preceding(ScalarValue::UInt64(val))
+            | WindowFrameBound::Following(ScalarValue::UInt64(val)) => Ok(val),
+            WindowFrameBound::Preceding(ref x) | WindowFrameBound::Following(ref x) => Err(
+                DaskPlannerError::Internal(format!("Unexpected window frame bound: {:?}", x))
+                    .into(),
+            ),
+            WindowFrameBound::CurrentRow => Ok(None),
         }
     }
     /// Returns if the frame bound is unbounded
     #[pyo3(name = "isUnbounded")]
-    pub fn is_unbounded(&self) -> bool {
-        match self.frame_bound {
-            WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => val.is_none(),
-            WindowFrameBound::CurrentRow => false,
+    pub fn is_unbounded(&self) -> PyResult<bool> {
+        match &self.frame_bound {
+            WindowFrameBound::Preceding(ScalarValue::UInt64(v))
+            | WindowFrameBound::Following(ScalarValue::UInt64(v)) => Ok(v.is_none()),
+            WindowFrameBound::Preceding(ref x) | WindowFrameBound::Following(ref x) => Err(
+                DaskPlannerError::Internal(format!("Unexpected window frame bound: {:?}", x))
+                    .into(),
+            ),
+            WindowFrameBound::CurrentRow => Ok(false),
         }
     }
 }
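
A minimal sketch (not part of the patch) of the bound handling that `get_offset` and `is_unbounded`
now perform, assuming the datafusion 14 `WindowFrameBound` variants that wrap a `ScalarValue`, as
matched above; the helper name `frame_offset` is hypothetical:

    use datafusion_common::ScalarValue;
    use datafusion_expr::WindowFrameBound;

    /// Returns Some(rows) for a bounded PRECEDING/FOLLOWING frame and None for
    /// CURRENT ROW or an unbounded/non-UInt64 bound.
    fn frame_offset(bound: &WindowFrameBound) -> Option<u64> {
        match bound {
            WindowFrameBound::Preceding(ScalarValue::UInt64(v))
            | WindowFrameBound::Following(ScalarValue::UInt64(v)) => *v,
            _ => None,
        }
    }
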
diff --git a/dask_planner/src/sql/optimizer.rs b/dask_planner/src/sql/optimizer.rs
index 8a0c87c7e..8067e8a5e 100644
--- a/dask_planner/src/sql/optimizer.rs
+++ b/dask_planner/src/sql/optimizer.rs
@@ -1,15 +1,18 @@
+use std::sync::Arc;
+
 use datafusion_common::DataFusionError;
 use datafusion_expr::LogicalPlan;
 use datafusion_optimizer::{
     common_subexpr_eliminate::CommonSubexprEliminate,
     decorrelate_where_exists::DecorrelateWhereExists,
     decorrelate_where_in::DecorrelateWhereIn,
-    eliminate_filter::EliminateFilter,
+    // TODO: need to handle EmptyRelation for GPU cases
+    // eliminate_filter::EliminateFilter,
     eliminate_limit::EliminateLimit,
     filter_null_join_keys::FilterNullJoinKeys,
-    filter_push_down::FilterPushDown,
+    inline_table_scan::InlineTableScan,
     limit_push_down::LimitPushDown,
-    optimizer::OptimizerRule,
+    optimizer::{Optimizer, OptimizerRule},
     projection_push_down::ProjectionPushDown,
     reduce_cross_join::ReduceCrossJoin,
     reduce_outer_join::ReduceOuterJoin,
@@ -26,85 +29,73 @@ use log::trace;
 mod eliminate_agg_distinct;
 use eliminate_agg_distinct::EliminateAggDistinct;
 
+mod filter_push_down;
+use filter_push_down::FilterPushDown;
+
 /// Houses the optimization logic for Dask-SQL. This optimization controls the optimizations
 /// and their ordering in regards to their impact on the underlying `LogicalPlan` instance
 pub struct DaskSqlOptimizer {
     skip_failing_rules: bool,
-    optimizations: Vec<Box<dyn OptimizerRule + Send + Sync>>,
+    optimizer: Optimizer,
 }
 
 impl DaskSqlOptimizer {
     /// Creates a new instance of the DaskSqlOptimizer with all the DataFusion desired
     /// optimizers as well as any custom `OptimizerRule` trait impls that might be desired.
     pub fn new(skip_failing_rules: bool) -> Self {
-        let rules: Vec<Box<dyn OptimizerRule + Send + Sync>> = vec![
-            Box::new(TypeCoercion::new()),
-            Box::new(SimplifyExpressions::new()),
-            Box::new(UnwrapCastInComparison::new()),
-            Box::new(DecorrelateWhereExists::new()),
-            Box::new(DecorrelateWhereIn::new()),
-            Box::new(ScalarSubqueryToJoin::new()),
-            Box::new(SubqueryFilterToJoin::new()),
+        let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> = vec![
+            Arc::new(InlineTableScan::new()),
+            Arc::new(TypeCoercion::new()),
+            Arc::new(SimplifyExpressions::new()),
+            Arc::new(UnwrapCastInComparison::new()),
+            Arc::new(DecorrelateWhereExists::new()),
+            Arc::new(DecorrelateWhereIn::new()),
+            Arc::new(ScalarSubqueryToJoin::new()),
+            Arc::new(SubqueryFilterToJoin::new()),
             // simplify expressions does not simplify expressions in subqueries, so we
             // run it again after running the optimizations that potentially converted
             // subqueries to joins
-            Box::new(SimplifyExpressions::new()),
-            Box::new(EliminateFilter::new()),
-            Box::new(ReduceCrossJoin::new()),
-            Box::new(CommonSubexprEliminate::new()),
-            Box::new(EliminateLimit::new()),
-            Box::new(RewriteDisjunctivePredicate::new()),
-            Box::new(FilterNullJoinKeys::default()),
-            Box::new(ReduceOuterJoin::new()),
-            Box::new(FilterPushDown::new()),
-            Box::new(LimitPushDown::new()),
-            // Box::new(SingleDistinctToGroupBy::new()),
+            Arc::new(SimplifyExpressions::new()),
+            // TODO: need to handle EmptyRelation for GPU cases
+            // Arc::new(EliminateFilter::new()),
+            Arc::new(ReduceCrossJoin::new()),
+            Arc::new(CommonSubexprEliminate::new()),
+            Arc::new(EliminateLimit::new()),
+            Arc::new(RewriteDisjunctivePredicate::new()),
+            Arc::new(FilterNullJoinKeys::default()),
+            Arc::new(ReduceOuterJoin::new()),
+            Arc::new(FilterPushDown::new()),
+            Arc::new(LimitPushDown::new()),
             // Dask-SQL specific optimizations
-            Box::new(EliminateAggDistinct::new()),
+            Arc::new(EliminateAggDistinct::new()),
             // The previous optimizations added expressions and projections,
             // that might benefit from the following rules
-            Box::new(SimplifyExpressions::new()),
-            Box::new(UnwrapCastInComparison::new()),
-            Box::new(CommonSubexprEliminate::new()),
-            Box::new(ProjectionPushDown::new()),
+            Arc::new(SimplifyExpressions::new()),
+            Arc::new(UnwrapCastInComparison::new()),
+            Arc::new(CommonSubexprEliminate::new()),
+            Arc::new(ProjectionPushDown::new()),
         ];
+
         Self {
             skip_failing_rules,
-            optimizations: rules,
+            optimizer: Optimizer::with_rules(rules),
         }
     }
 
     /// Iterates through the configured `OptimizerRule`(s) to transform the input `LogicalPlan`
     /// to its final optimized form
-    pub(crate) fn run_optimizations(
-        &self,
-        plan: LogicalPlan,
-    ) -> Result<LogicalPlan, DataFusionError> {
-        let mut resulting_plan: LogicalPlan = plan;
-        for optimization in &self.optimizations {
-            match optimization.optimize(&resulting_plan, &mut OptimizerConfig::new()) {
-                Ok(optimized_plan) => {
-                    trace!(
-                        "== AFTER APPLYING RULE {} ==\n{}",
-                        optimization.name(),
-                        optimized_plan.display_indent()
-                    );
-                    resulting_plan = optimized_plan
-                }
-                Err(e) => {
-                    if self.skip_failing_rules {
-                        println!(
-                            "Skipping optimizer rule {} due to unexpected error: {}",
-                            optimization.name(),
-                            e
-                        );
-                    } else {
-                        return Err(e);
-                    }
-                }
-            }
-        }
-        Ok(resulting_plan)
+    pub(crate) fn optimize(&self, plan: LogicalPlan) -> Result<LogicalPlan, DataFusionError> {
+        let mut config =
+            OptimizerConfig::default().with_skip_failing_rules(self.skip_failing_rules);
+        self.optimizer.optimize(&plan, &mut config, Self::observe)
+    }
+
+    fn observe(optimized_plan: &LogicalPlan, optimization: &dyn OptimizerRule) {
+        trace!(
+            "== AFTER APPLYING RULE {} ==\n{}\n",
+            optimization.name(),
+            optimized_plan.display_indent()
+        );
     }
 }
 
@@ -158,7 +149,7 @@ mod tests {
 
         // optimize the logical plan
         let optimizer = DaskSqlOptimizer::new(false);
-        optimizer.run_optimizations(plan)
+        optimizer.optimize(plan)
     }
 
     struct MySchemaProvider {}
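
A hedged sketch (not part of the patch) of the datafusion 14 driver pattern that the new
`DaskSqlOptimizer::optimize` builds on: construct an `Optimizer` from a rule list, then run it with
an `OptimizerConfig` and an observer callback. The single-rule function below is illustrative only:

    use std::sync::Arc;

    use datafusion_common::DataFusionError;
    use datafusion_expr::LogicalPlan;
    use datafusion_optimizer::{
        optimizer::{Optimizer, OptimizerRule},
        simplify_expressions::SimplifyExpressions,
        OptimizerConfig,
    };

    // Run a single rule over a plan, printing the plan after each pass.
    fn simplify_only(plan: &LogicalPlan) -> Result<LogicalPlan, DataFusionError> {
        let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> =
            vec![Arc::new(SimplifyExpressions::new())];
        let optimizer = Optimizer::with_rules(rules);
        let mut config = OptimizerConfig::default().with_skip_failing_rules(false);
        optimizer.optimize(plan, &mut config, |optimized_plan, rule| {
            println!("after {}: {}", rule.name(), optimized_plan.display_indent());
        })
    }
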
diff --git a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs
index 8ec91b4fe..411e0a25a 100644
--- a/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs
+++ b/dask_planner/src/sql/optimizer/eliminate_agg_distinct.rs
@@ -457,7 +457,7 @@ mod tests {
     fn assert_fully_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
         let optimizer = DaskSqlOptimizer::new(false);
         let optimized_plan = optimizer
-            .run_optimizations(plan.clone())
+            .optimize(plan.clone())
             .expect("failed to optimize plan");
         let formatted_plan = format!("{}", optimized_plan.display_indent());
         assert_eq!(expected, formatted_plan);
diff --git a/dask_planner/src/sql/optimizer/filter_push_down.rs b/dask_planner/src/sql/optimizer/filter_push_down.rs
new file mode 100644
index 000000000..ac0429774
--- /dev/null
+++ b/dask_planner/src/sql/optimizer/filter_push_down.rs
@@ -0,0 +1,641 @@
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
+
+use std::{
+    collections::{HashMap, HashSet},
+    iter::once,
+};
+
+use datafusion_common::{Column, DFSchema, DataFusionError, Result};
+use datafusion_expr::{
+    col,
+    expr_rewriter::{replace_col, ExprRewritable, ExprRewriter},
+    logical_plan::{
+        Aggregate,
+        CrossJoin,
+        Join,
+        JoinType,
+        Limit,
+        LogicalPlan,
+        Projection,
+        TableScan,
+        Union,
+    },
+    utils::{expr_to_columns, exprlist_to_columns, from_plan},
+    Expr,
+    TableProviderFilterPushDown,
+};
+use datafusion_optimizer::{utils, OptimizerConfig, OptimizerRule};
+
+/// Filter Push Down optimizer rule pushes filter clauses down the plan
+/// # Introduction
+/// A filter-commutative operation is an operation for which filter(op(data)) = op(filter(data)).
+/// An example of a filter-commutative operation is a projection; a counter-example is `limit`.
+///
+/// The filter-commutative property is column-specific. An aggregate grouped by A on SUM(B)
+/// can commute with a filter that depends on A only, but does not commute with a filter that depends
+/// on SUM(B).
+///
+/// This optimizer commutes filters with filter-commutative operations to push the filters
+/// as close as possible to the scans, re-writing the filter expressions by every
+/// projection that changes the filter's expression.
+///
+/// Filter: b Gt Int64(10)
+///     Projection: a AS b
+///
+/// is optimized to
+///
+/// Projection: a AS b
+///     Filter: a Gt Int64(10)  <--- changed from b to a
+///
+/// This performs a single pass through the plan. When it passes through a filter, it stores that filter,
+/// and when it reaches a node that does not commute with it, it adds the filter to that place.
+/// When it passes through a projection, it re-writes the filter's expression taking into account that projection.
+/// When multiple filters would have been written, it `AND` their expressions into a single expression.
+#[derive(Default)]
+pub struct FilterPushDown {}
+
+/// Filter predicate represented by tuple of expression and its columns
+type Predicate = (Expr, HashSet<Column>);
+
+/// Multiple filter predicates represented by tuple of expressions vector
+/// and corresponding expression columns vector
+type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
+
+#[derive(Debug, Clone, Default)]
+struct State {
+    // (predicate, columns on the predicate)
+    filters: Vec<Predicate>,
+}
+
+impl State {
+    fn append_predicates(&mut self, predicates: Predicates) {
+        predicates
+            .0
+            .into_iter()
+            .zip(predicates.1)
+            .for_each(|(expr, cols)| self.filters.push((expr.clone(), cols.clone())))
+    }
+}
+
+/// returns all predicates in `state` that depend on any of `used_columns`
+/// or the ones that do not reference any columns (e.g. WHERE 1=1)
+fn get_predicates<'a>(state: &'a State, used_columns: &HashSet<Column>) -> Predicates<'a> {
+    state
+        .filters
+        .iter()
+        .filter(|(_, columns)| {
+            columns.is_empty()
+                || !columns
+                    .intersection(used_columns)
+                    .collect::<HashSet<_>>()
+                    .is_empty()
+        })
+        .map(|&(ref a, ref b)| (a, b))
+        .unzip()
+}
+
+/// Optimizes the plan
+fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input, state.clone()))
+        .collect::<Result<Vec<_>>>()?;
+
+    let expr = plan.expressions();
+    from_plan(plan, &expr, &new_inputs)
+}
+
+// remove all filters from `filters` that are in `predicate_columns`
+fn remove_filters(filters: &[Predicate], predicate_columns: &[&HashSet<Column>]) -> Vec<Predicate> {
+    filters
+        .iter()
+        .filter(|(_, columns)| !predicate_columns.contains(&columns))
+        .cloned()
+        .collect::<Vec<_>>()
+}
+
+/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
+/// in `state` depend on the columns `used_columns`.
+fn issue_filters(
+    mut state: State,
+    used_columns: HashSet<Column>,
+    plan: &LogicalPlan,
+) -> Result<LogicalPlan> {
+    let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
+
+    if predicates.is_empty() {
+        // all filters can be pushed down => optimize inputs and return new plan
+        return push_down(&state, plan);
+    }
+
+    let plan = utils::add_filter(plan.clone(), &predicates)?;
+
+    state.filters = remove_filters(&state.filters, &predicate_columns);
+
+    // continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
+    push_down(&state, &plan)
+}
+
+// For a given JOIN logical plan, determine whether each side of the join is preserved.
+// We say a join side is preserved if the join returns all or a subset of the rows from
+// the relevant side, such that each row of the output table directly maps to a row of
+// the preserved input table. If a table is not preserved, it can provide extra null rows.
+// That is, there may be rows in the output table that don't directly map to a row in the
+// input table.
+//
+// For example:
+//   - In an inner join, both sides are preserved, because each row of the output
+//     maps directly to a row from each side.
+//   - In a left join, the left side is preserved and the right is not, because
+//     there may be rows in the output that don't directly map to a row in the
+//     right input (due to nulls filling where there is no match on the right).
+//
+// This is important because we can always push down post-join filters to a preserved
+// side of the join, assuming the filter only references columns from that side. For the
+// non-preserved side it can be more tricky.
+//
+// Returns a tuple of booleans - (left_preserved, right_preserved).
+fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
+    match plan {
+        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
+            JoinType::Inner => Ok((true, true)),
+            JoinType::Left => Ok((true, false)),
+            JoinType::Right => Ok((false, true)),
+            JoinType::Full => Ok((false, false)),
+            // No columns from the right side of the join can be referenced in output
+            // predicates for semi/anti joins, so whether we specify t/f doesn't matter.
+            JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
+            _ => todo!(),
+        },
+        LogicalPlan::CrossJoin(_) => Ok((true, true)),
+        _ => Err(DataFusionError::Internal(
+            "lr_is_preserved only valid for JOIN nodes".to_string(),
+        )),
+    }
+}
+
+// For a given JOIN logical plan, determine whether each side of the join is preserved
+// in terms of join filtering.
+// Predicates from join filter can only be pushed to preserved join side.
+fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
+    match plan {
+        LogicalPlan::Join(Join { join_type, .. }) => match join_type {
+            JoinType::Inner => Ok((true, true)),
+            JoinType::Left => Ok((false, true)),
+            JoinType::Right => Ok((true, false)),
+            JoinType::Full => Ok((false, false)),
+            JoinType::LeftSemi | JoinType::LeftAnti => {
+                // filter_push_down does not yet support SEMI/ANTI joins with join conditions
+                Ok((false, false))
+            }
+            _ => todo!(),
+        },
+        LogicalPlan::CrossJoin(_) => Err(DataFusionError::Internal(
+            "on_lr_is_preserved cannot be applied to CROSSJOIN nodes".to_string(),
+        )),
+        _ => Err(DataFusionError::Internal(
+            "on_lr_is_preserved only valid for JOIN nodes".to_string(),
+        )),
+    }
+}
+
+// Determine which predicates in state can be pushed down to a given side of a join.
+// To determine this, we need to know the schema of the relevant join side and whether
+// or not the side's rows are preserved when joining. If the side is not preserved, we
+// do not push down anything. Otherwise we can push down predicates where all of the
+// relevant columns are contained on the relevant join side's schema.
+fn get_pushable_join_predicates<'a>(
+    filters: &'a [Predicate],
+    schema: &DFSchema,
+    preserved: bool,
+) -> Predicates<'a> {
+    if !preserved {
+        return (vec![], vec![]);
+    }
+
+    let schema_columns = schema
+        .fields()
+        .iter()
+        .flat_map(|f| {
+            [
+                f.qualified_column(),
+                // we need to push down filter using unqualified column as well
+                f.unqualified_column(),
+            ]
+        })
+        .collect::<HashSet<_>>();
+
+    filters
+        .iter()
+        .filter(|(_, columns)| {
+            let all_columns_in_schema = schema_columns
+                .intersection(columns)
+                .collect::<HashSet<_>>()
+                .len()
+                == columns.len();
+            all_columns_in_schema
+        })
+        .map(|(a, b)| (a, b))
+        .unzip()
+}
+
+fn optimize_join(
+    mut state: State,
+    plan: &LogicalPlan,
+    left: &LogicalPlan,
+    right: &LogicalPlan,
+    on_filter: Vec,
+) -> Result<LogicalPlan> {
+    // Get pushable predicates from current optimizer state
+    let (left_preserved, right_preserved) = lr_is_preserved(plan)?;
+    let to_left = get_pushable_join_predicates(&state.filters, left.schema(), left_preserved);
+    let to_right = get_pushable_join_predicates(&state.filters, right.schema(), right_preserved);
+    let to_keep: Predicates = state
+        .filters
+        .iter()
+        .filter(|(e, _)| !to_left.0.contains(&e) && !to_right.0.contains(&e))
+        .map(|(a, b)| (a, b))
+        .unzip();
+
+    // Get pushable predicates from join filter
+    let (on_to_left, on_to_right, on_to_keep) = if on_filter.is_empty() {
+        ((vec![], vec![]), (vec![], vec![]), vec![])
+    } else {
+        let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(plan)?;
+        let on_to_left = get_pushable_join_predicates(&on_filter, left.schema(), on_left_preserved);
+        let on_to_right =
+            get_pushable_join_predicates(&on_filter, right.schema(), on_right_preserved);
+        let on_to_keep = on_filter
+            .iter()
+            .filter(|(e, _)| !on_to_left.0.contains(&e) && !on_to_right.0.contains(&e))
+            .map(|(a, _)| a.clone())
+            .collect::<Vec<_>>();
+
+        (on_to_left, on_to_right, on_to_keep)
+    };
+
+    // Build new filter states using pushable predicates
+    // from current optimizer states and from ON clause.
+    // Then recursively call optimization for both join inputs
+    let mut left_state = State { filters: vec![] };
+    left_state.append_predicates(to_left);
+    left_state.append_predicates(on_to_left);
+    let left = optimize(left, left_state)?;
+
+    let mut right_state = State { filters: vec![] };
+    right_state.append_predicates(to_right);
+    right_state.append_predicates(on_to_right);
+    let right = optimize(right, right_state)?;
+
+    // Create a new Join with the new `left` and `right`
+    //
+    // expressions() output for Join is a vector consisting of
+    //   1. join keys - columns mentioned in ON clause
+    //   2. optional predicate - in case join filter is not empty,
+    //      it always will be the last element, otherwise result
+    //      vector will contain only join keys (without additional
+    //      element representing filter).
+    let expr = plan.expressions();
+    let expr = if !on_filter.is_empty() && on_to_keep.is_empty() {
+        // New filter expression is None - should remove last element
+        expr[..expr.len() - 1].to_vec()
+    } else if !on_to_keep.is_empty() {
+        // Replace last element with new filter expression
+        expr[..expr.len() - 1]
+            .iter()
+            .cloned()
+            .chain(once(on_to_keep.into_iter().reduce(Expr::and).unwrap()))
+            .collect()
+    } else {
+        plan.expressions()
+    };
+    let plan = from_plan(plan, &expr, &[left, right])?;
+
+    if to_keep.0.is_empty() {
+        Ok(plan)
+    } else {
+        // wrap the join on the filter whose predicates must be kept
+        let plan = utils::add_filter(plan, &to_keep.0);
+        state.filters = remove_filters(&state.filters, &to_keep.1);
+        plan
+    }
+}
+
+fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
+    match plan {
+        LogicalPlan::Explain { .. } => {
+            // push the optimization to the plan of this explain
+            push_down(&state, plan)
+        }
+        LogicalPlan::Analyze { .. } => push_down(&state, plan),
+        LogicalPlan::Filter(filter) => {
+            let predicates = utils::split_conjunction(filter.predicate());
+
+            predicates
+                .into_iter()
+                .try_for_each::<_, Result<()>>(|predicate| {
+                    let mut columns: HashSet<Column> = HashSet::new();
+                    expr_to_columns(predicate, &mut columns)?;
+                    state.filters.push((predicate.clone(), columns));
+                    Ok(())
+                })?;
+
+            optimize(filter.input(), state)
+        }
+        LogicalPlan::Projection(Projection {
+            input,
+            expr,
+            schema,
+            alias: _,
+        }) => {
+            // A projection is filter-commutable, but re-writes all predicate expressions
+            // collect projection.
+            let projection = schema
+                .fields()
+                .iter()
+                .enumerate()
+                .flat_map(|(i, field)| {
+                    // strip alias, as they should not be part of filters
+                    let expr = match &expr[i] {
+                        Expr::Alias(expr, _) => expr.as_ref().clone(),
+                        expr => expr.clone(),
+                    };
+
+                    // Convert both qualified and unqualified fields
+                    [
+                        (field.name().clone(), expr.clone()),
+                        (field.qualified_name(), expr),
+                    ]
+                })
+                .collect::<HashMap<_, _>>();
+
+            // re-write all filters based on this projection
+            // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
+            for (predicate, columns) in state.filters.iter_mut() {
+                *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
+
+                columns.clear();
+                expr_to_columns(predicate, columns)?;
+            }
+
+            // optimize inner
+            let new_input = optimize(input, state)?;
+            Ok(from_plan(plan, expr, &[new_input])?)
+        }
+        LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
+            // An aggregate's aggregate columns are _not_ filter-commutable => collect these:
+            // * columns whose aggregation expression depends on
+            // * the aggregation columns themselves
+
+            // construct set of columns that `aggr_expr` depends on
+            let mut used_columns = HashSet::new();
+            exprlist_to_columns(aggr_expr, &mut used_columns)?;
+
+            let agg_columns = aggr_expr
+                .iter()
+                .map(|x| Ok(Column::from_name(x.display_name()?)))
+                .collect::<Result<Vec<_>>>()?;
+            used_columns.extend(agg_columns);
+
+            issue_filters(state, used_columns, plan)
+        }
+        LogicalPlan::Sort { .. } => {
+            // sort is filter-commutable
+            push_down(&state, plan)
+        }
+        LogicalPlan::Union(Union {
+            inputs: _,
+            schema,
+            alias: _,
+        }) => {
+            // the union changes all qualifiers while building the logical plan, so we need
+            // to rewrite filters to push unqualified columns to inputs
+            let projection = schema
+                .fields()
+                .iter()
+                .map(|field| (field.qualified_name(), col(field.name())))
+                .collect::<HashMap<_, _>>();
+
+            // rewriting predicate expressions using unqualified names as replacements
+            if !projection.is_empty() {
+                for (predicate, columns) in state.filters.iter_mut() {
+                    *predicate = replace_cols_by_name(predicate.clone(), &projection)?;
+
+                    columns.clear();
+                    expr_to_columns(predicate, columns)?;
+                }
+            }
+
+            push_down(&state, plan)
+        }
+        LogicalPlan::Limit(Limit { input, .. }) => {
+            // limit is _not_ filter-commutable => collect all columns from its input
+            let used_columns = input
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.qualified_column())
+                .collect::<HashSet<_>>();
+            issue_filters(state, used_columns, plan)
+        }
+        LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => {
+            optimize_join(state, plan, left, right, vec![])
+        }
+        LogicalPlan::Join(Join {
+            left,
+            right,
+            on,
+            filter,
+            join_type,
+            ..
+        }) => {
+            // Convert JOIN ON predicate to Predicates
+            let on_filters = filter
+                .as_ref()
+                .map(|e| {
+                    let predicates = utils::split_conjunction(e);
+
+                    predicates
+                        .into_iter()
+                        .map(|e| {
+                            let mut accum = HashSet::new();
+                            expr_to_columns(e, &mut accum)?;
+                            Ok((e.clone(), accum))
+                        })
+                        .collect::<Result<Vec<_>>>()
+                })
+                .unwrap_or_else(|| Ok(vec![]))?;
+
+            if *join_type == JoinType::Inner {
+                // For inner joins, duplicate filters for joined columns so filters can be pushed down
+                // to both sides. Take the following query as an example:
+                //
+                // ```sql
+                // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
+                // ```
+                //
+                // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
+                // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
+                //
+                // Join clauses with `Using` constraints also take advantage of this logic to make sure
+                // predicates referencing the shared join columns are pushed to both sides.
+                // This logic should also be applied to conditions in the JOIN ON clause.
+                let join_side_filters = state
+                    .filters
+                    .iter()
+                    .chain(on_filters.iter())
+                    .filter_map(|(predicate, columns)| {
+                        let mut join_cols_to_replace = HashMap::new();
+                        for col in columns.iter() {
+                            for (l, r) in on {
+                                if col == l {
+                                    join_cols_to_replace.insert(col, r);
+                                    break;
+                                } else if col == r {
+                                    join_cols_to_replace.insert(col, l);
+                                    break;
+                                }
+                            }
+                        }
+
+                        if join_cols_to_replace.is_empty() {
+                            return None;
+                        }
+
+                        let join_side_predicate =
+                            match replace_col(predicate.clone(), &join_cols_to_replace) {
+                                Ok(p) => p,
+                                Err(e) => {
+                                    return Some(Err(e));
+                                }
+                            };
+
+                        let join_side_columns = columns
+                            .clone()
+                            .into_iter()
+                            // replace keys in join_cols_to_replace with values in resulting column
+                            // set
+                            .filter(|c| !join_cols_to_replace.contains_key(c))
+                            .chain(join_cols_to_replace.iter().map(|(_, v)| (*v).clone()))
+                            .collect();
+
+                        Some(Ok((join_side_predicate, join_side_columns)))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                state.filters.extend(join_side_filters);
+            }
+
+            optimize_join(state, plan, left, right, on_filters)
+        }
+        LogicalPlan::TableScan(TableScan {
+            source,
+            projected_schema,
+            filters,
+            projection,
+            table_name,
+            fetch,
+        }) => {
+            let mut used_columns = HashSet::new();
+            let mut new_filters = filters.clone();
+
+            for (filter_expr, cols) in &state.filters {
+                let (preserve_filter_node, add_to_provider) =
+                    match source.supports_filter_pushdown(filter_expr)? {
+                        TableProviderFilterPushDown::Unsupported => (true, false),
+                        TableProviderFilterPushDown::Inexact => (true, true),
+                        TableProviderFilterPushDown::Exact => (false, true),
+                    };
+
+                if preserve_filter_node {
+                    used_columns.extend(cols.clone());
+                }
+
+                if add_to_provider {
+                    // Don't add expression again if it's already present in
+                    // pushed down filters.
+                    if new_filters.contains(filter_expr) {
+                        continue;
+                    }
+                    new_filters.push(filter_expr.clone());
+                }
+            }
+
+            issue_filters(
+                state,
+                used_columns,
+                &LogicalPlan::TableScan(TableScan {
+                    source: source.clone(),
+                    projection: projection.clone(),
+                    projected_schema: projected_schema.clone(),
+                    table_name: table_name.clone(),
+                    filters: new_filters,
+                    fetch: *fetch,
+                }),
+            )
+        }
+        _ => {
+            // all other plans are _not_ filter-commutable
+            let used_columns = plan
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.qualified_column())
+                .collect::<HashSet<_>>();
+            issue_filters(state, used_columns, plan)
+        }
+    }
+}
+
+impl OptimizerRule for FilterPushDown {
+    fn name(&self) -> &str {
+        "filter_push_down"
+    }
+
+    fn optimize(&self, plan: &LogicalPlan, _: &mut OptimizerConfig) -> Result<LogicalPlan> {
+        optimize(plan, State::default())
+    }
+}
+
+impl FilterPushDown {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+/// replaces columns by its name on the projection.
+fn replace_cols_by_name(e: Expr, replace_map: &HashMap<String, Expr>) -> Result<Expr> {
+    struct ColumnReplacer<'a> {
+        replace_map: &'a HashMap<String, Expr>,
+    }
+    }
+
+    impl<'a> ExprRewriter for ColumnReplacer<'a> {
+        fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+            if let Expr::Column(c) = &expr {
+                match self.replace_map.get(&c.flat_name()) {
+                    Some(new_c) => Ok(new_c.clone()),
+                    None => Ok(expr),
+                }
+            } else {
+                Ok(expr)
+            }
+        }
+    }
+
+    e.rewrite(&mut ColumnReplacer { replace_map })
+}
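
The TableScan branch above makes a per-filter, three-way decision based on the table provider's answer: whether to keep the Filter node in the plan and whether to forward the predicate to the provider. A minimal Python sketch of that decision table (the dict and its name are illustrative, not part of the patch):

    # (preserve_filter_node, add_to_provider), mirroring the match on
    # TableProviderFilterPushDown in the TableScan branch of optimize()
    PUSHDOWN_HANDLING = {
        "Unsupported": (True, False),  # keep the Filter node; nothing reaches the scan
        "Inexact": (True, True),       # keep the Filter node AND forward it to the scan
        "Exact": (False, True),        # rely on the scan alone; the Filter node can be dropped
    }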
diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs
index fd30163e6..65a9b24da 100644
--- a/dask_planner/src/sql/types.rs
+++ b/dask_planner/src/sql/types.rs
@@ -339,7 +339,7 @@ impl SqlTypeName {
                 let tokens = tokenizer.tokenize().map_err(DaskPlannerError::from)?;
                 let mut parser = Parser::new(tokens, &dialect);
                 match parser.parse_data_type().map_err(DaskPlannerError::from)? {
-                    SQLType::Decimal(_, _) => Ok(SqlTypeName::DECIMAL),
+                    SQLType::Decimal(_) => Ok(SqlTypeName::DECIMAL),
                     SQLType::Binary(_) => Ok(SqlTypeName::BINARY),
                     SQLType::Varbinary(_) => Ok(SqlTypeName::VARBINARY),
                     SQLType::Varchar(_) | SQLType::Nvarchar(_) => Ok(SqlTypeName::VARCHAR),
diff --git a/dask_sql/physical/rel/logical/join.py b/dask_sql/physical/rel/logical/join.py
index e3a88cf25..fcebb166d 100644
--- a/dask_sql/physical/rel/logical/join.py
+++ b/dask_sql/physical/rel/logical/join.py
@@ -43,7 +43,7 @@ class DaskJoinPlugin(BaseRelPlugin):
         "LEFT": "left",
         "RIGHT": "right",
         "FULL": "outer",
-        "SEMI": "inner",  # TODO: Need research here! This is likely not a true inner join
+        "LEFTSEMI": "inner",  # TODO: Need research here! This is likely not a true inner join
     }
 
     def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
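
The renamed key tracks the join type name reported by newer DataFusion plans ("LEFTSEMI" rather than "SEMI"); the lowering to an inner merge is unchanged and still carries the TODO above. As a hedged illustration, the snippet below shows one query shape that is commonly planned as a left semi join (table names and data are made up):

    import pandas as pd
    from dask_sql import Context

    c = Context()
    c.create_table("df1", pd.DataFrame({"a": [1, 2, 3]}))
    c.create_table("df2", pd.DataFrame({"a": [2, 3, 4]}))

    # An IN-subquery is typically rewritten into a (left) semi join before it
    # reaches this plugin, which currently lowers it to an inner merge.
    result = c.sql("SELECT * FROM df1 WHERE a IN (SELECT a FROM df2)")
    print(result.compute())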
diff --git a/dask_sql/physical/utils/filter.py b/dask_sql/physical/utils/filter.py
index ddfacd6ab..5309289c4 100644
--- a/dask_sql/physical/utils/filter.py
+++ b/dask_sql/physical/utils/filter.py
@@ -91,7 +91,8 @@ def attempt_predicate_pushdown(ddf: dd.DataFrame) -> dd.DataFrame:
     try:
         return dsk.layers[name]._regenerate_collection(
             dsk,
-            new_kwargs={io_layer: {"filters": filters}},
+            # TODO: shouldn't need to specify index=False after dask#9661 is merged
+            new_kwargs={io_layer: {"filters": filters, "index": False}},
         )
     except ValueError as err:
         # Most-likely failed to apply filters in read_parquet.
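
The extra index=False is a workaround until dask#9661 lands, as the TODO notes. For context, a minimal sketch of how this helper is used (the parquet path and column are illustrative):

    import dask.dataframe as dd
    from dask_sql.physical.utils.filter import attempt_predicate_pushdown

    ddf = dd.read_parquet("/tmp/table.parquet")
    filtered = ddf[ddf["b"] > 5]

    # Regenerates the underlying read_parquet layer with filters=[...] (and, with
    # this change, index=False) when the graph allows it; if regeneration fails,
    # the helper is expected to fall back to the unmodified graph.
    optimized = attempt_predicate_pushdown(filtered)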
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index a556b025f..69b964514 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -1,11 +1,15 @@
+import dask
 import dask.dataframe as dd
 import pandas as pd
 import pytest
 from dask.utils_test import hlg_layer
+from packaging.version import parse as parseVersion
 
 from dask_sql._compat import INT_NAN_IMPLEMENTED
 from tests.utils import assert_eq
 
+DASK_GT_2022_4_2 = parseVersion(dask.__version__) >= parseVersion("2022.4.2")
+
 
 def test_filter(c, df):
     return_df = c.sql("SELECT * FROM df WHERE a < 2")
@@ -208,8 +212,10 @@ def test_predicate_pushdown(c, parquet_ddf, query, df_func, filters):
     df = parquet_ddf
     expected_df = df_func(df)
 
-    # TODO: divisions should be consistent when successfully doing predicate pushdown
-    assert_eq(return_df, expected_df, check_divisions=False)
+    # divisions aren't equal for older dask versions
+    assert_eq(
+        return_df, expected_df, check_index=False, check_divisions=DASK_GT_2022_4_2
+    )
 
 
 def test_filtered_csv(tmpdir, c):

From 8eb0230151cc1b061f7be1a029fee8edcbb357f7 Mon Sep 17 00:00:00 2001
From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com>
Date: Wed, 16 Nov 2022 11:55:44 -0800
Subject: [PATCH 18/19] Bump actions/checkout from 2 to 3 (#920)

Bumps [actions/checkout](https://github.com/actions/checkout) from 2 to 3.
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v2...v3)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] 

Signed-off-by: dependabot[bot] 
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Ayush Dattagupta 
---
 .github/workflows/conda.yml         | 2 +-
 .github/workflows/docker.yml        | 2 +-
 .github/workflows/rust.yml          | 2 +-
 .github/workflows/style.yml         | 2 +-
 .github/workflows/test-upstream.yml | 8 ++++----
 .github/workflows/test.yml          | 8 ++++----
 .github/workflows/update-gpuci.yml  | 2 +-
 7 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index 566895bdb..4836ef2ed 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -32,7 +32,7 @@ jobs:
     name: Build (and optionally upload) the conda nightly
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           fetch-depth: 0
       - name: Set up Python
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 2ffaaf1bf..3d79aab64 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -20,7 +20,7 @@ jobs:
     if: github.repository == 'dask-contrib/dask-sql'
     steps:
       - name: Check out the repo
-        uses: actions/checkout@v2
+        uses: actions/checkout@v3
       - name: Set up QEMU
         uses: docker/setup-qemu-action@v1
       - name: Set up Docker Buildx
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index d61be9d5f..0cbdabab5 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -23,7 +23,7 @@ jobs:
     outputs:
       triggered: ${{ steps.detect-trigger.outputs.trigger-found }}
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           fetch-depth: 2
       - uses: xarray-contrib/ci-trigger@v1.1
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index a74a6f1eb..abd407fac 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -13,7 +13,7 @@ jobs:
     name: Run pre-commit hooks
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - uses: actions/setup-python@v2
       - uses: actions-rs/toolchain@v1
         with:
diff --git a/.github/workflows/test-upstream.yml b/.github/workflows/test-upstream.yml
index 79678cf59..bd682c114 100644
--- a/.github/workflows/test-upstream.yml
+++ b/.github/workflows/test-upstream.yml
@@ -46,7 +46,7 @@ jobs:
         os: [ubuntu-latest, windows-latest, macos-latest]
         python: ["3.8", "3.9", "3.10"]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           fetch-depth: 0 # Fetch all history for all branches and tags.
       - name: Set up Python
@@ -86,7 +86,7 @@ jobs:
     name: "Test upstream dev in a dask cluster"
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
         uses: conda-incubator/setup-miniconda@v2.2.0
         with:
@@ -138,7 +138,7 @@ jobs:
     name: "Test importing with bare requirements and upstream dev"
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
         uses: conda-incubator/setup-miniconda@v2.2.0
         with:
@@ -181,7 +181,7 @@ jobs:
       )
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Report failures
         uses: actions/github-script@v3
         with:
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 8b330acd5..624ec0022 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -24,7 +24,7 @@ jobs:
     outputs:
       triggered: ${{ steps.detect-trigger.outputs.trigger-found }}
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
         with:
           fetch-depth: 2
       - uses: xarray-contrib/ci-trigger@v1.1
@@ -44,7 +44,7 @@ jobs:
         os: [ubuntu-latest, windows-latest, macos-latest]
         python: ["3.8", "3.9", "3.10"]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
         uses: conda-incubator/setup-miniconda@v2.2.0
         with:
@@ -87,7 +87,7 @@ jobs:
     needs: [detect-ci-trigger]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
         uses: conda-incubator/setup-miniconda@v2.2.0
         with:
@@ -137,7 +137,7 @@ jobs:
     needs: [detect-ci-trigger]
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
         uses: conda-incubator/setup-miniconda@v2.2.0
         with:
diff --git a/.github/workflows/update-gpuci.yml b/.github/workflows/update-gpuci.yml
index 275c59caa..62634c987 100644
--- a/.github/workflows/update-gpuci.yml
+++ b/.github/workflows/update-gpuci.yml
@@ -11,7 +11,7 @@ jobs:
     if: github.repository == 'dask-contrib/dask-sql'
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
 
       - name: Parse current axis YAML
         uses: the-coding-turtle/ga-yaml-parser@v0.1.1

From 4933d342b6f42cbeab4db4744236d2c6b913b0f8 Mon Sep 17 00:00:00 2001
From: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
Date: Wed, 16 Nov 2022 16:22:04 -0800
Subject: [PATCH 19/19] Support `to_timestamp` (#838)

* functionality and pytest

* style fix

* add format param

* lint

* remove quotes from result

* return date64 instead of str

* lint

* Apply suggestions from code review

Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>

* add string input and test

* lint

* timestampadd parser test

* change Variadic to Exact

* rust test and pdlike

* fix rust test maybe?

* minor change

* fix rust test

* gpu test?

* edit gpu test

* try again

* dask_cudf

* try except to_cupy

* use dd and add scalar/string tests

* style fix

* pass most gpu tests?

* Update call.py

* Apply suggestions from code review

Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>

* add pytest.mark.skip and comments for gpu tests

* update with Ayush's suggestions

* link to issue

* Update tests/integration/test_rex.py

Co-authored-by: Ayush Dattagupta 

* use np instead of datetime for scalars

* wrap str case in np.datetime64

Co-authored-by: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com>
Co-authored-by: Ayush Dattagupta 
---
 dask_planner/src/dialect.rs        |  78 ++++++++++++++++++-
 dask_planner/src/parser.rs         |  46 +++++++++++
 dask_planner/src/sql.rs            |  18 +++++
 dask_sql/physical/rex/core/call.py |  37 +++++++++
 tests/integration/test_rex.py      | 119 +++++++++++++++++++++++++++++
 5 files changed, 297 insertions(+), 1 deletion(-)

diff --git a/dask_planner/src/dialect.rs b/dask_planner/src/dialect.rs
index 973f76f4f..b27c81ec3 100644
--- a/dask_planner/src/dialect.rs
+++ b/dask_planner/src/dialect.rs
@@ -1,6 +1,11 @@
 use core::{iter::Peekable, str::Chars};
 
-use datafusion_sql::sqlparser::dialect::Dialect;
+use datafusion_sql::sqlparser::{
+    ast::{Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, Value},
+    dialect::Dialect,
+    parser::{Parser, ParserError},
+    tokenizer::Token,
+};
 
 #[derive(Debug)]
 pub struct DaskDialect {}
@@ -37,4 +42,75 @@ impl Dialect for DaskDialect {
     fn supports_filter_during_aggregation(&self) -> bool {
         true
     }
+
+    /// override expression parsing
+    fn parse_prefix(&self, parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
+        fn parse_expr(parser: &mut Parser) -> Result<Option<Expr>, ParserError> {
+            match parser.peek_token() {
+                Token::Word(w) if w.value.to_lowercase() == "timestampadd" => {
+                    // TIMESTAMPADD(YEAR, 2, d)
+                    parser.next_token(); // skip timestampadd
+                    parser.expect_token(&Token::LParen)?;
+                    let time_unit = parser.next_token();
+                    parser.expect_token(&Token::Comma)?;
+                    let n = parser.parse_expr()?;
+                    parser.expect_token(&Token::Comma)?;
+                    let expr = parser.parse_expr()?;
+                    parser.expect_token(&Token::RParen)?;
+
+                    // convert to function args
+                    let args = vec![
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
+                            Value::SingleQuotedString(time_unit.to_string()),
+                        ))),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(n)),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
+                    ];
+
+                    Ok(Some(Expr::Function(Function {
+                        name: ObjectName(vec![Ident::new("timestampadd")]),
+                        args,
+                        over: None,
+                        distinct: false,
+                        special: false,
+                    })))
+                }
+                Token::Word(w) if w.value.to_lowercase() == "to_timestamp" => {
+                    // TO_TIMESTAMP(d, "%d/%m/%Y")
+                    parser.next_token(); // skip to_timestamp
+                    parser.expect_token(&Token::LParen)?;
+                    let expr = parser.parse_expr()?;
+                    let comma = parser.consume_token(&Token::Comma);
+                    let time_format = if comma {
+                        parser.next_token().to_string()
+                    } else {
+                        "%Y-%m-%d %H:%M:%S".to_string()
+                    };
+                    parser.expect_token(&Token::RParen)?;
+
+                    // convert to function args
+                    let args = vec![
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)),
+                        FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
+                            Value::SingleQuotedString(time_format),
+                        ))),
+                    ];
+
+                    Ok(Some(Expr::Function(Function {
+                        name: ObjectName(vec![Ident::new("dsql_totimestamp")]),
+                        args,
+                        over: None,
+                        distinct: false,
+                        special: false,
+                    })))
+                }
+                _ => Ok(None),
+            }
+        }
+        match parse_expr(parser) {
+            Ok(Some(expr)) => Some(Ok(expr)),
+            Ok(None) => None,
+            Err(e) => Some(Err(e)),
+        }
+    }
 }
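
With this dialect hook, TIMESTAMPADD and TO_TIMESTAMP are rewritten at parse time into plain function calls (timestampadd and dsql_totimestamp) whose unit or format argument travels as a string literal. A minimal usage sketch through the Python API, mirroring the tests added later in this patch (table contents are illustrative):

    import pandas as pd
    from dask_sql import Context

    c = Context()
    c.create_table("df", pd.DataFrame({"a": ["02/28/1997", "03/28/1997"]}))

    # Explicit format string; without one, '%Y-%m-%d %H:%M:%S' is assumed
    c.sql("SELECT TO_TIMESTAMP(a, '%m/%d/%Y') AS date FROM df").compute()

    # Scalar string input works as well
    c.sql("SELECT TO_TIMESTAMP('1997-02-28 10:30:00') AS date").compute()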
diff --git a/dask_planner/src/parser.rs b/dask_planner/src/parser.rs
index d743af901..61be0e1cd 100644
--- a/dask_planner/src/parser.rs
+++ b/dask_planner/src/parser.rs
@@ -1236,6 +1236,52 @@ impl<'a> DaskParser<'a> {
 mod test {
     use crate::parser::{DaskParser, DaskStatement};
 
+    #[test]
+    fn timestampadd() {
+        let sql = "SELECT TIMESTAMPADD(YEAR, 2, d) FROM t";
+        let statements = DaskParser::parse_sql(sql).unwrap();
+        assert_eq!(1, statements.len());
+        let actual = format!("{:?}", statements[0]);
+        let expected = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"timestampadd\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Value(SingleQuotedString(\"YEAR\")))), \
+        Unnamed(Expr(Value(Number(\"2\", false)))), \
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None })))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual.contains(expected));
+    }
+
+    #[test]
+    fn to_timestamp() {
+        let sql1 = "SELECT TO_TIMESTAMP(d) FROM t";
+        let statements1 = DaskParser::parse_sql(sql1).unwrap();
+        assert_eq!(1, statements1.len());
+        let actual1 = format!("{:?}", statements1[0]);
+        let expected1 = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), \
+        Unnamed(Expr(Value(SingleQuotedString(\"%Y-%m-%d %H:%M:%S\"))))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual1.contains(expected1));
+
+        let sql2 = "SELECT TO_TIMESTAMP(d, \"%d/%m/%Y\") FROM t";
+        let statements2 = DaskParser::parse_sql(sql2).unwrap();
+        assert_eq!(1, statements2.len());
+        let actual2 = format!("{:?}", statements2[0]);
+        let expected2 = "projection: [\
+        UnnamedExpr(Function(Function { name: ObjectName([Ident { value: \"dsql_totimestamp\", quote_style: None }]), \
+        args: [\
+        Unnamed(Expr(Identifier(Ident { value: \"d\", quote_style: None }))), \
+        Unnamed(Expr(Value(SingleQuotedString(\"\\\"%d/%m/%Y\\\"\"))))\
+        ], over: None, distinct: false, special: false }))\
+        ]";
+        assert!(actual2.contains(expected2));
+    }
+
     #[test]
     fn create_model() {
         let sql = r#"CREATE MODEL my_model WITH (
diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index d52211ff7..bf6ce16ab 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -152,6 +152,24 @@ impl ContextProvider for DaskSQLContext {
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
+            "dsql_totimestamp" => {
+                let sig = Signature::one_of(
+                    vec![
+                        TypeSignature::Exact(vec![DataType::Int8, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int16, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int32, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt8, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt16, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt32, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::UInt64, DataType::Utf8]),
+                        TypeSignature::Exact(vec![DataType::Utf8, DataType::Utf8]),
+                    ],
+                    Volatility::Immutable,
+                );
+                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Date64)));
+                return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
+            }
             "mod" => {
                 let sig = generate_numeric_signatures(2);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index a66b178dc..6a5b01c17 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -1,6 +1,7 @@
 import logging
 import operator
 import re
+from datetime import datetime
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, Callable, Union
 
@@ -613,6 +614,41 @@ def extract(self, what, df: SeriesOrScalar):
             raise NotImplementedError(f"Extraction of {what} is not (yet) implemented.")
 
 
+class ToTimestampOperation(Operation):
+    def __init__(self):
+        super().__init__(self.to_timestamp)
+
+    def to_timestamp(self, df, format):
+        default_format = "%Y-%m-%d %H:%M:%S"
+        # Remove double and single quotes from string
+        format = format.replace('"', "")
+        format = format.replace("'", "")
+
+        # TODO: format timestamps for GPU tests
+        if "cudf" in str(type(df)):
+            if format != default_format:
+                raise RuntimeError("Non-default timestamp formats not supported on GPU")
+            if df.dtype == "object":
+                return df
+            else:
+                nanoseconds_to_seconds = 10**9
+                return df * nanoseconds_to_seconds
+        # String cases
+        elif type(df) == str:
+            return np.datetime64(datetime.strptime(df, format))
+        elif df.dtype == "object":
+            return dd.to_datetime(df, format=format)
+        # Integer cases
+        elif np.isscalar(df):
+            if format != default_format:
+                raise RuntimeError("Integer input does not accept a format argument")
+            return np.datetime64(int(df), "s")
+        else:
+            if format != default_format:
+                raise RuntimeError("Integer input does not accept a format argument")
+            return dd.to_datetime(df, unit="s")
+
+
 class YearOperation(Operation):
     def __init__(self):
         super().__init__(self.extract_year)
@@ -990,6 +1026,7 @@ class RexCallPlugin(BaseRexPlugin):
             lambda x: x + pd.tseries.offsets.MonthEnd(1),
             lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
         ),
+        "dsql_totimestamp": ToTimestampOperation(),
         # Temporary UDF functions that need to be moved after this POC
         "datepart": DatePartOperation(),
         "year": YearOperation(),
diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py
index b7d455fe3..510bf953b 100644
--- a/tests/integration/test_rex.py
+++ b/tests/integration/test_rex.py
@@ -677,3 +677,122 @@ def test_date_functions(c):
             FROM df
             """
         )
+
+
+@pytest.mark.parametrize(
+    "gpu",
+    [
+        False,
+        pytest.param(
+            True,
+            marks=(
+                pytest.mark.gpu,
+                pytest.mark.xfail(
+                    reason="Failing due to dask-cudf bug https://github.com/rapidsai/cudf/issues/12062"
+                ),
+            ),
+        ),
+    ],
+)
+def test_totimestamp(c, gpu):
+    df = pd.DataFrame(
+        {
+            "a": np.array([1203073300, 1406073600, 2806073600]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a) AS date FROM df
+    """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(2008, 2, 15, 11, 1, 40),
+                datetime(2014, 7, 23),
+                datetime(2058, 12, 2, 16, 53, 20),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    df = pd.DataFrame(
+        {
+            "a": np.array(["1997-02-28 10:30:00", "1997-03-28 10:30:01"]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a) AS date FROM df
+    """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 10, 30, 0),
+                datetime(1997, 3, 28, 10, 30, 1),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    df = pd.DataFrame(
+        {
+            "a": np.array(["02/28/1997", "03/28/1997"]),
+        }
+    )
+    c.create_table("df", df, gpu=gpu)
+
+    df = c.sql(
+        """
+        SELECT to_timestamp(a, "%m/%d/%Y") AS date FROM df
+    """
+    )
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 0, 0, 0),
+                datetime(1997, 3, 28, 0, 0, 0),
+            ],
+        }
+    )
+    # https://github.com/rapidsai/cudf/issues/12062
+    if not gpu:
+        assert_eq(df, expected_df, check_dtype=False)
+
+    int_input = 1203073300
+    df = c.sql(f"SELECT to_timestamp({int_input}) as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(2008, 2, 15, 11, 1, 40),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    string_input = "1997-02-28 10:30:00"
+    df = c.sql(f"SELECT to_timestamp('{string_input}') as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 10, 30, 0),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)
+
+    string_input = "02/28/1997"
+    df = c.sql(f"SELECT to_timestamp('{string_input}', '%m/%d/%Y') as date")
+    expected_df = pd.DataFrame(
+        {
+            "date": [
+                datetime(1997, 2, 28, 0, 0, 0),
+            ],
+        }
+    )
+    assert_eq(df, expected_df, check_dtype=False)