Skip to content

Commit

Permalink
Tempfix
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed May 3, 2024
1 parent 9bc2ab0 commit afc4b8e
Show file tree
Hide file tree
Showing 40 changed files with 234 additions and 57 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/ml/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def array_to_vector(col: Column) -> Column:


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
Expand All @@ -54,7 +55,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("ml.connect.functions tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import os

from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_classification import ClassificationTestsMixin

have_torch = True
# TODO(SPARK-48083): Reenable this test case
have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ
try:
import torch # noqa: F401
except ImportError:
Expand All @@ -31,7 +33,7 @@
class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
Expand Down
10 changes: 7 additions & 3 deletions python/pyspark/ml/tests/connect/test_connect_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest

from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin

have_torcheval = True
# TODO(SPARK-48084): Reenable this test case
have_torcheval = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ
try:
import torcheval # noqa: F401
except ImportError:
Expand All @@ -29,7 +31,9 @@
@unittest.skipIf(not have_torcheval, "torcheval is required")
class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
Expand Down
7 changes: 5 additions & 2 deletions python/pyspark/ml/tests/connect/test_connect_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest

from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin


class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/ml/tests/connect/test_connect_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pyspark.ml.connect import functions as CF


@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
class SparkConnectMLFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, SQLTestUtils):
"""These test cases exercise the interface to the proto plan
generation but do not call Spark."""
Expand Down
5 changes: 3 additions & 2 deletions python/pyspark/ml/tests/connect/test_connect_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest

from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin


class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
Expand Down
7 changes: 5 additions & 2 deletions python/pyspark/ml/tests/connect/test_connect_summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest

from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin


class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = SparkSession.builder.remote("local[2]").getOrCreate()
self.spark = SparkSession.builder.remote(
os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
).getOrCreate()

def tearDown(self) -> None:
self.spark.stop()
Expand Down
5 changes: 3 additions & 2 deletions python/pyspark/ml/tests/connect/test_connect_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import unittest
from pyspark.sql import SparkSession
from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin


@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
SparkSession.builder.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
# limitations under the License.
#

import os
import unittest
from pyspark.sql import SparkSession

have_torch = True
have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ
try:
import torch # noqa: F401
except ImportError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import shutil
import unittest

have_torch = True
have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ
try:
import torch # noqa: F401
except ImportError:
Expand Down Expand Up @@ -81,7 +81,7 @@ def _get_inputs_for_test_local_training_succeeds(self):
]


@unittest.skipIf(not have_torch, "torch is required")
@unittest.skipIf("SPARK_SKIP_CONNECT_COMPAT_TESTS" in os.environ, "Requires JVM access")
class TorchDistributorLocalUnitTestsIIOnConnect(
TorchDistributorLocalUnitTestsMixin, unittest.TestCase
):
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/ml/torch/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
# limitations under the License.
#

import os
import numpy as np
import unittest

have_torch = True
have_torch = "SPARK_SKIP_CONNECT_COMPAT_TESTS" not in os.environ
try:
import torch # noqa: F401
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/avro/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.avro.functions tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,14 +326,17 @@ def registerFunction(


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.catalog

globs = pyspark.sql.connect.catalog.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.catalog tests").remote("local[4]").getOrCreate()
PySparkSession.builder.appName("sql.connect.catalog tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,17 @@ def __nonzero__(self) -> None:


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.column

globs = pyspark.sql.connect.column.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.column tests").remote("local[4]").getOrCreate()
PySparkSession.builder.appName("sql.connect.column tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,17 @@ def _checkType(self, obj: Any, identifier: str) -> None:


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.conf

globs = pyspark.sql.connect.conf.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.conf tests").remote("local[4]").getOrCreate()
PySparkSession.builder.appName("sql.connect.conf tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,7 +2150,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3906,6 +3906,7 @@ def call_function(funcName: str, *cols: "ColumnOrName") -> Column:

def _test() -> None:
import sys
import os
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.functions
Expand All @@ -3914,7 +3915,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.functions tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ def _extract_cols(gd: "GroupedData") -> List[Column]:


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
Expand All @@ -396,7 +397,9 @@ def _test() -> None:
globs = pyspark.sql.connect.group.__dict__.copy()

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.group tests").remote("local[4]").getOrCreate()
PySparkSession.builder.appName("sql.connect.group tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

(failure_count, test_count) = doctest.testmod(
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/protobuf/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.protobuf.functions tests")
.remote("local[2]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]"))
.getOrCreate()
)

Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ def overwritePartitions(self) -> None:

def _test() -> None:
import sys
import os
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.readwriter
Expand All @@ -838,7 +839,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.readwriter tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/sql/connect/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,14 +910,17 @@ def session_id(self) -> str:


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
import pyspark.sql.connect.session

globs = pyspark.sql.connect.session.__dict__.copy()
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.session tests").remote("local[4]").getOrCreate()
PySparkSession.builder.appName("sql.connect.session tests")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

# Uses PySpark session to test builder.
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/streaming/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.streaming.query tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/sql/connect/streaming/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def toTable(


def _test() -> None:
import os
import sys
import doctest
from pyspark.sql import SparkSession as PySparkSession
Expand All @@ -595,7 +596,7 @@ def _test() -> None:

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.streaming.readwriter tests")
.remote("local[4]")
.remote(os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[4]"))
.getOrCreate()
)

Expand Down
Loading

0 comments on commit afc4b8e

Please sign in to comment.