diff --git a/pom.xml b/pom.xml
index 7ba2f364..3390bf72 100644
--- a/pom.xml
+++ b/pom.xml
@@ -34,9 +34,9 @@
1.8
1.8
UTF-8
- 2.12.18
- 2.12
- 4.2.0
+ 2.13.13
+ 2.13
+ 4.20.5
3.16.0
${scala.compat.version}
Snowpark ${project.version}
@@ -134,7 +134,7 @@
org.scalatest
scalatest_${scala.compat.version}
- 3.0.5
+ 3.2.18
test
@@ -496,7 +496,7 @@
test-coverage
- 2.12.15
+ 2.13.13
diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
index 91f0021b..107a06e4 100644
--- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala
+++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -1797,7 +1797,10 @@ class DataFrame private[snowpark] (
* @param firstArg The first argument to pass to the specified table function.
* @param remaining A list of any additional arguments for the specified table function.
*/
- def join(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+ def join(
+ func: com.snowflake.snowpark.TableFunction,
+ firstArg: Column,
+ remaining: Column*): DataFrame =
join(func, firstArg +: remaining)
/**
@@ -1824,7 +1827,7 @@ class DataFrame private[snowpark] (
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
- def join(func: TableFunction, args: Seq[Column]): DataFrame =
+ def join(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame =
joinTableFunction(func.call(args: _*), None)
/**
@@ -1854,7 +1857,7 @@ class DataFrame private[snowpark] (
* @param orderBy A list of columns ordered by.
*/
def join(
- func: TableFunction,
+ func: com.snowflake.snowpark.TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame =
@@ -1892,7 +1895,7 @@ class DataFrame private[snowpark] (
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
- def join(func: TableFunction, args: Map[String, Column]): DataFrame =
+ def join(func: com.snowflake.snowpark.TableFunction, args: Map[String, Column]): DataFrame =
joinTableFunction(func.call(args), None)
/**
@@ -1929,7 +1932,7 @@ class DataFrame private[snowpark] (
* @param orderBy A list of columns ordered by.
*/
def join(
- func: TableFunction,
+ func: com.snowflake.snowpark.TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame =
@@ -2366,7 +2369,7 @@ class DataFrame private[snowpark] (
}
}
lines.append(value.substring(startIndex))
- lines
+ lines.toSeq
}
def convertValueToString(value: Any): String =
@@ -2501,7 +2504,7 @@ class DataFrame private[snowpark] (
* and view name.
*/
def createOrReplaceView(multipartIdentifier: java.util.List[String]): Unit =
- createOrReplaceView(multipartIdentifier.asScala)
+ createOrReplaceView(multipartIdentifier.asScala.toSeq)
/**
* Creates a temporary view that returns the same results as this DataFrame.
@@ -2564,7 +2567,7 @@ class DataFrame private[snowpark] (
* view name.
*/
def createOrReplaceTempView(multipartIdentifier: java.util.List[String]): Unit =
- createOrReplaceTempView(multipartIdentifier.asScala)
+ createOrReplaceTempView(multipartIdentifier.asScala.toSeq)
private def doCreateOrReplaceView(viewName: String, viewType: ViewType): Unit = {
session.conn.telemetry.reportActionCreateOrReplaceView()
diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala
index a1dc5aef..8617eb7b 100644
--- a/src/main/scala/com/snowflake/snowpark/Row.scala
+++ b/src/main/scala/com/snowflake/snowpark/Row.scala
@@ -350,7 +350,7 @@ class Row protected (values: Array[Any]) extends Serializable {
* @since 1.13.0
* @group getter
*/
- def getSeq[T](index: Int): Seq[T] = {
+ def getSeq[T: ClassTag](index: Int): Seq[T] = {
val result = getAs[Array[_]](index)
result.map {
case x: T => x
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index 633c8e42..b590b9a5 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -444,7 +444,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* @since 0.2.0
*/
def table(multipartIdentifier: java.util.List[String]): Updatable =
- table(multipartIdentifier.asScala)
+ table(multipartIdentifier.asScala.toSeq)
/**
* Returns an Updatable that points to the specified table.
@@ -497,7 +497,10 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* @param firstArg the first function argument of the given table function.
* @param remaining all remaining function arguments.
*/
- def tableFunction(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+ def tableFunction(
+ func: com.snowflake.snowpark.TableFunction,
+ firstArg: Column,
+ remaining: Column*): DataFrame =
tableFunction(func, firstArg +: remaining)
/**
@@ -525,7 +528,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* referred from the built-in list from tableFunctions.
* @param args function arguments of the given table function.
*/
- def tableFunction(func: TableFunction, args: Seq[Column]): DataFrame = {
+ def tableFunction(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame = {
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
@@ -569,7 +572,9 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
* Some functions, like flatten, have named parameters.
* use this map to assign values to the corresponding parameters.
*/
- def tableFunction(func: TableFunction, args: Map[String, Column]): DataFrame = {
+ def tableFunction(
+ func: com.snowflake.snowpark.TableFunction,
+ args: Map[String, Column]): DataFrame = {
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.values.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
@@ -615,9 +620,9 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
def tableFunction(func: Column): DataFrame = {
func.expr match {
case TFunction(funcName, args) =>
- tableFunction(TableFunction(funcName), args.map(Column(_)))
+ tableFunction(com.snowflake.snowpark.TableFunction(funcName), args.map(Column(_)))
case NamedArgumentsTableFunction(funcName, argMap) =>
- tableFunction(TableFunction(funcName), argMap.map {
+ tableFunction(com.snowflake.snowpark.TableFunction(funcName), argMap.map {
case (key, value) => key -> Column(value)
})
case _ => throw ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT()
diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
index 58bd69e7..d18d275c 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
@@ -319,7 +319,7 @@ object JavaUtils {
def stringArrayToStringSeq(arr: Array[String]): Seq[String] = arr
def objectListToAnySeq(input: java.util.List[java.util.List[Object]]): Seq[Seq[Any]] =
- input.asScala.map(list => list.asScala)
+ input.asScala.map(list => list.asScala.toSeq).toSeq
def registerUDF(
udfRegistration: UDFRegistration,
@@ -360,18 +360,26 @@ object JavaUtils {
map.asScala.toMap
def javaMapToScalaWithVariantConversion(map: java.util.Map[_, _]): Map[Any, Any] =
- map.asScala.map {
- case (key, value: com.snowflake.snowpark_java.types.Variant) =>
- key -> InternalUtils.toScalaVariant(value)
- case (key, value) => key -> value
- }.toMap
+ map.asScala
+ .map((e: (Any, Any)) => {
+ (e) match {
+ case (key, value: com.snowflake.snowpark_java.types.Variant) =>
+ key -> InternalUtils.toScalaVariant(value)
+ case (key, value) => key -> value
+ }
+ })
+ .toMap
def scalaMapToJavaWithVariantConversion(map: Map[_, _]): java.util.Map[Object, Object] =
- map.map {
- case (key, value: com.snowflake.snowpark.types.Variant) =>
- key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
- case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
- }.asJava
+ map
+ .map((e: (Any, Any)) => {
+ (e) match {
+ case (key, value: com.snowflake.snowpark.types.Variant) =>
+ key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
+ case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
+ }
+ })
+ .asJava
def serialize(obj: Any): Array[Byte] = {
val bos = new ByteArrayOutputStream()
diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
index 92728eaf..e920ba31 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
@@ -52,6 +52,7 @@ import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.{
import java.util
import scala.collection.mutable
+import scala.reflect.classTag
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConverters._
@@ -317,7 +318,7 @@ private[snowpark] class ServerConnection(
private[snowpark] def resultSetToRows(statement: Statement): Array[Row] = withValidConnection {
val iterator = resultSetToIterator(statement)._1
- val buff = mutable.ArrayBuilder.make[Row]()
+ val buff = mutable.ArrayBuilder.make[Row](classTag[Row])
while (iterator.hasNext) {
buff += iterator.next()
}
diff --git a/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala b/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
index acf9b62e..ed8a9a36 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
@@ -385,7 +385,7 @@ class UDXRegistrationHandler(session: Session) extends Logging {
if (actionID <= session.getLastCanceledID) {
throw ErrorMessage.MISC_QUERY_IS_CANCELLED()
}
- allUrls
+ allUrls.toSeq
}
(allImports, targetJarStageLocation)
}
diff --git a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
index d464566c..4415afd0 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
@@ -27,9 +27,9 @@ object Utils extends Logging {
// Define the compat scala version instead of reading from property file
// because it fails to read the property file in some environment such as
// VSCode worksheet.
- val ScalaCompatVersion: String = "2.12"
+ val ScalaCompatVersion: String = "2.13"
// Minimum minor version. We require version to be greater than 2.12.9
- val ScalaMinimumMinorVersion: String = "2.12.9"
+ val ScalaMinimumMinorVersion: String = "2.13.9"
// Minimum GS version for us to identify as Snowpark client
val MinimumGSVersionForSnowparkClientType: String = "5.20.0"
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
index a3218758..381d079e 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
@@ -74,9 +74,9 @@ class SnowflakePlan(
}
val supportAsyncMode = subqueryPlans.forall(_.supportAsyncMode)
SnowflakePlan(
- preQueries :+ queries.last,
+ (preQueries :+ queries.last).toSeq,
newSchemaQuery,
- newPostActions,
+ newPostActions.toSeq,
session,
sourcePlan,
supportAsyncMode)
diff --git a/src/main/scala/com/snowflake/snowpark/types/Variant.scala b/src/main/scala/com/snowflake/snowpark/types/Variant.scala
index 28ce5b67..4f00f302 100644
--- a/src/main/scala/com/snowflake/snowpark/types/Variant.scala
+++ b/src/main/scala/com/snowflake/snowpark/types/Variant.scala
@@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.node.{ArrayNode, JsonNodeFactory, ObjectNo
import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger}
import java.sql.{Date, Time, Timestamp}
import java.util.{List => JavaList, Map => JavaMap}
+import scala.jdk.FunctionConverters._
import scala.collection.JavaConverters._
import Variant._
import org.apache.commons.codec.binary.{Base64, Hex}
@@ -225,7 +226,7 @@ class Variant private[snowpark] (
*
* @since 0.2.0
*/
- def this(list: JavaList[Object]) = this(list.asScala)
+ def this(list: JavaList[Object]) = this(list.asScala.toSeq)
/**
* Creates a Variant from array
@@ -244,7 +245,11 @@ class Variant private[snowpark] (
{
def mapToNode(map: JavaMap[Object, Object]): ObjectNode = {
val result = MAPPER.createObjectNode()
- map.keySet().forEach(key => result.set(key.toString, objectToJsonNode(map.get(key))))
+ val consumer =
+ (key: Object) => result.set(key.toString, objectToJsonNode(map.get(key)))
+ map
+ .keySet()
+ .forEach(consumer.asJavaConsumer)
result
}
obj match {
diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
index c0f1ed3e..85bc310d 100644
--- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
+++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.code_verification
import com.snowflake.snowpark.{CodeVerification, DataFrame}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
// verify API Java and Scala API contain same functions
@CodeVerification
diff --git a/src/test/scala/com/snowflake/code_verification/PomSuite.scala b/src/test/scala/com/snowflake/code_verification/PomSuite.scala
index 6dca6810..74d6d331 100644
--- a/src/test/scala/com/snowflake/code_verification/PomSuite.scala
+++ b/src/test/scala/com/snowflake/code_verification/PomSuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.code_verification
import com.snowflake.snowpark.CodeVerification
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import scala.collection.mutable
diff --git a/src/test/scala/com/snowflake/perf/PerfBase.scala b/src/test/scala/com/snowflake/perf/PerfBase.scala
index 6969249a..508788d5 100644
--- a/src/test/scala/com/snowflake/perf/PerfBase.scala
+++ b/src/test/scala/com/snowflake/perf/PerfBase.scala
@@ -77,6 +77,7 @@ trait PerfBase extends SNTestBase {
test(testName) {
try {
writeResult(testName, timer(func))
+ succeed
} catch {
case ex: Exception =>
writeResult(testName, -1.0) // -1.0 if failed
@@ -84,7 +85,10 @@ trait PerfBase extends SNTestBase {
}
}
} else {
- ignore(testName)(func)
+ ignore(testName) {
+ func
+ succeed
+ }
}
}
}
diff --git a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
index a28d167a..37bac0d4 100644
--- a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
@@ -18,6 +18,7 @@ import com.snowflake.snowpark.internal.analyzer.{
import com.snowflake.snowpark.types._
import net.snowflake.client.core.SFSessionProperty
import net.snowflake.client.jdbc.SnowflakeSQLException
+import org.scalatest.Assertion
import java.nio.file.Files
import java.sql.{Date, Timestamp}
@@ -26,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.util.Random
+import scala.language.postfixOps
class APIInternalSuite extends TestData {
private val userSchema: StructType = StructType(
@@ -578,10 +580,11 @@ class APIInternalSuite extends TestData {
testWithAlteredSessionParameter(() => {
import session.implicits._
val schema = StructType(Seq(StructField("ID", LongType)))
- val largeData = new ArrayBuffer[Row]()
+ val largeDataBuf = new ArrayBuffer[Row]()
for (i <- 0 to 1024) {
- largeData.append(Row(i.toLong))
+ largeDataBuf.append(Row(i.toLong))
}
+ val largeData = largeDataBuf.toSeq
// With specific schema
var df = session.createDataFrame(largeData, schema)
assert(df.snowflakePlan.queries.size == 3)
@@ -596,7 +599,7 @@ class APIInternalSuite extends TestData {
for (i <- 0 to 1024) {
inferData.append(i.toLong)
}
- df = inferData.toDF("id2")
+ df = inferData.toSeq.toDF("id2")
assert(df.snowflakePlan.queries.size == 3)
assert(df.snowflakePlan.queries(0).sql.trim().startsWith("CREATE SCOPED TEMPORARY TABLE"))
assert(df.snowflakePlan.queries(1).sql.trim().startsWith("INSERT INTO"))
@@ -823,6 +826,7 @@ class APIInternalSuite extends TestData {
val (rows, meta) = session.conn.getResultAndMetadata(session.sql(query).snowflakePlan)
assert(rows.length == 0 || rows(0).length == meta.size)
}
+ succeed
}
// reader
@@ -895,7 +899,7 @@ class APIInternalSuite extends TestData {
assert(ex2.errorCode.equals("0321"))
}
- def checkExecuteAndGetQueryId(df: DataFrame): Unit = {
+ def checkExecuteAndGetQueryId(df: DataFrame): Assertion = {
val query = Query.resultScanQuery(df.executeAndGetQueryId())
val res = query.runQueryGetResult(session.conn, mutable.HashMap.empty[String, String], false)
res.attributes
@@ -907,7 +911,7 @@ class APIInternalSuite extends TestData {
checkExecuteAndGetQueryIdWithStatementParameter(df)
}
- private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Unit = {
+ private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Assertion = {
val testQueryTagValue = s"test_query_tag_${Random.nextLong().abs}"
val queryId = df.executeAndGetQueryId(Map("QUERY_TAG" -> testQueryTagValue))
val rows = session
@@ -1007,7 +1011,7 @@ class APIInternalSuite extends TestData {
largeData.append(
Row(1025, null, null, null, null, null, null, null, null, null, null, null, null))
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
checkExecuteAndGetQueryId(df)
// Statement parameters are applied for all queries.
@@ -1039,6 +1043,7 @@ class APIInternalSuite extends TestData {
// case 2: test int/boolean parameter
multipleQueriesDF1.executeAndGetQueryId(
Map("STATEMENT_TIMEOUT_IN_SECONDS" -> 100, "USE_CACHED_RESULT" -> false))
+ succeed
}
test("VariantTypes.getType") {
@@ -1052,7 +1057,7 @@ class APIInternalSuite extends TestData {
assert(Variant.VariantTypes.getType("Timestamp") == Variant.VariantTypes.Timestamp)
assert(Variant.VariantTypes.getType("Array") == Variant.VariantTypes.Array)
assert(Variant.VariantTypes.getType("Object") == Variant.VariantTypes.Object)
- intercept[Exception] { Variant.VariantTypes.getType("not_exist_type") }
+ assertThrows[Exception] { Variant.VariantTypes.getType("not_exist_type") }
}
test("HasCachedResult doesn't cache again") {
diff --git a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
index 2509176a..b8363191 100644
--- a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
@@ -109,16 +109,16 @@ class DropTempObjectsSuite extends SNTestBase {
TempObjectType.Table,
"db.schema.tempName1",
TempType.Temporary)
- assertTrue(session.getTempObjectMap.contains("db.schema.tempName1"))
+ assert(session.getTempObjectMap.contains("db.schema.tempName1"))
session.recordTempObjectIfNecessary(
TempObjectType.Table,
"db.schema.tempName2",
TempType.ScopedTemporary)
- assertFalse(session.getTempObjectMap.contains("db.schema.tempName2"))
+ assert(!session.getTempObjectMap.contains("db.schema.tempName2"))
session.recordTempObjectIfNecessary(
TempObjectType.Table,
"db.schema.tempName3",
TempType.Permanent)
- assertFalse(session.getTempObjectMap.contains("db.schema.tempName3"))
+ assert(!session.getTempObjectMap.contains("db.schema.tempName3"))
}
}
diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
index 937b93e6..714fceb1 100644
--- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -6,7 +6,7 @@ import com.snowflake.snowpark.internal.ParameterUtils.{
MIN_REQUEST_TIMEOUT_IN_SECONDS,
SnowparkRequestTimeoutInSeconds
}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
class ErrorMessageSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
index 1ada9af9..a78fab76 100644
--- a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
@@ -250,6 +250,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
emptyChecker(CurrentRow)
emptyChecker(UnspecifiedFrame)
binaryChecker(SpecifiedWindowFrame(RowFrame, _, _))
+ succeed
}
test("star children and dependent columns") {
@@ -475,6 +476,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
leafAnalyzerChecker(CurrentRow)
leafAnalyzerChecker(UnspecifiedFrame)
binaryAnalyzerChecker(SpecifiedWindowFrame(RowFrame, _, _))
+ succeed
}
test("star - analyze") {
@@ -489,6 +491,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
assert(exp.analyze(x => x) == exp)
assert(exp.analyze(_ => att2) == att2)
leafAnalyzerChecker(Star(Seq.empty))
+ succeed
}
test("WindowSpecDefinition - analyze") {
@@ -988,6 +991,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
assert(key.name == "\"COL3\"")
assert(value.name == "\"COL3\"")
}
+ succeed
}
test("TableDelete - Analyzer") {
@@ -1125,5 +1129,6 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
leafSimplifierChecker(
SnowflakePlan(Seq.empty, "222", session, None, supportAsyncMode = false))
+ succeed
}
}
diff --git a/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala b/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
index c3007066..eb3df273 100644
--- a/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
@@ -5,7 +5,7 @@ import java.util.jar.{JarFile, JarOutputStream}
import java.util.zip.ZipException
import com.snowflake.snowpark.internal.{FatJarBuilder, JavaCodeCompiler}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
diff --git a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
index fbd6101d..04d1cefd 100644
--- a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
@@ -1,6 +1,6 @@
package com.snowflake.snowpark
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import com.snowflake.snowpark_test._
import java.io.ByteArrayOutputStream
diff --git a/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala b/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
index 4e3faa54..a347ecd8 100644
--- a/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.snowpark
import com.snowflake.snowpark.internal.{InMemoryClassObject, JavaCodeCompiler, UDFClassPath}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
class JavaCodeCompilerSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala b/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
index d3a31d9e..6ca97dd3 100644
--- a/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.snowpark
import com.snowflake.snowpark.internal.Logging
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
class LoggingSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
index 86d43968..92f1acc4 100644
--- a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
@@ -3,6 +3,7 @@ package com.snowflake.snowpark
import com.snowflake.snowpark.internal.ParameterUtils
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types.IntegerType
+import org.scalatest.Assertion
import scala.language.implicitConversions
@@ -219,7 +220,9 @@ class NewColumnReferenceSuite extends SNTestBase {
case class TestInternalAlias(name: String) extends TestColumnName
implicit def stringToOriginalName(name: String): TestOriginalName =
TestOriginalName(name)
- private def verifyOutputName(output: Seq[Attribute], columnNames: Seq[TestColumnName]): Unit = {
+ private def verifyOutputName(
+ output: Seq[Attribute],
+ columnNames: Seq[TestColumnName]): Assertion = {
assert(output.size == columnNames.size)
assert(output.map(_.name).zip(columnNames).forall {
case (name, TestOriginalName(n)) => name == quoteName(n)
@@ -280,6 +283,7 @@ class NewColumnReferenceSuite extends SNTestBase {
verifyUnaryNode(child => TableUpdate("a", Map.empty, None, Some(child)))
verifyUnaryNode(child => SnowflakeCreateTable("a", SaveMode.Append, Some(child)))
verifyBinaryNode((plan1, plan2) => SimplifiedUnion(Seq(plan1, plan2)))
+ succeed
}
private val project1 = Project(Seq(Attribute("a", IntegerType)), Range(1, 1, 1))
diff --git a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
index 75365e5b..d461ad5f 100644
--- a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
@@ -77,6 +77,7 @@ class ParameterSuite extends SNTestBase {
// no need to verify PKCS#1 format key additionally,
// since all Github Action tests use PKCS#1 key to authenticate with Snowflake server.
ParameterUtils.parsePrivateKey(generatePKCS8Key())
+ succeed
}
private def generatePKCS8Key(): String = {
diff --git a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
index cba06aa4..eeb859e9 100644
--- a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
@@ -1,15 +1,14 @@
package com.snowflake.snowpark
-import java.io.{BufferedReader, OutputStreamWriter, StringReader}
+import java.io.{BufferedReader, OutputStreamWriter, StringReader, PrintWriter => JPrintWriter}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths, StandardCopyOption}
-
import com.snowflake.snowpark.internal.Utils
import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter._
import scala.tools.nsc.util.stringFromStream
import scala.sys.process._
+import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}
@UDFTest
class ReplSuite extends TestData {
@@ -48,7 +47,6 @@ class ReplSuite extends TestData {
Console.withOut(outputStream) {
val input = new BufferedReader(new StringReader(preLoad + code))
val output = new JPrintWriter(new OutputStreamWriter(outputStream))
- val repl = new ILoop(input, output)
val settings = new Settings()
if (inMemory) {
settings.processArgumentString("-Yrepl-class-based")
@@ -56,7 +54,8 @@ class ReplSuite extends TestData {
settings.processArgumentString("-Yrepl-class-based -Yrepl-outdir repl_classes")
}
settings.classpath.value = sys.props("java.class.path")
- repl.process(settings)
+ val repl = new ILoop(ShellConfig(settings), input, output)
+ repl.run(settings)
}
}.replaceAll("scala> ", "")
}
diff --git a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
index ebb5f9c4..e13261a2 100644
--- a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
@@ -40,6 +40,7 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, integers)
assert(attribute.length == integers.length)
integers.indices.foreach(index => assert(attribute(index).dataType == LongType))
+ succeed
}
test("float data type") {
@@ -47,6 +48,7 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, floats)
assert(attribute.length == floats.length)
floats.indices.foreach(index => assert(attribute(index).dataType == DoubleType))
+ succeed
}
test("string data types") {
@@ -54,6 +56,7 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, strings)
assert(attribute.length == strings.length)
strings.indices.foreach(index => assert(attribute(index).dataType == StringType))
+ succeed
}
test("binary data types") {
@@ -61,6 +64,7 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, binaries)
assert(attribute.length == binaries.length)
binaries.indices.foreach(index => assert(attribute(index).dataType == BinaryType))
+ succeed
}
test("logical data type") {
@@ -69,6 +73,7 @@ class ResultAttributesSuite extends SNTestBase {
assert(attributes.length == 1)
assert(attributes.head.dataType == BooleanType)
dropTable(tableName)
+ succeed
}
test("date & time data type") {
@@ -83,6 +88,7 @@ class ResultAttributesSuite extends SNTestBase {
val attribute = getAttributesWithTypes(tableName, dates.map(_._1))
assert(attribute.length == dates.length)
dates.indices.foreach(index => assert(attribute(index).dataType == dates(index)._2))
+ succeed
}
test("semi-structured data types") {
@@ -106,5 +112,6 @@ class ResultAttributesSuite extends SNTestBase {
index =>
assert(attribute(index).dataType ==
ArrayType(StringType)))
+ succeed
}
}
diff --git a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala
index 1144b207..407e34c1 100644
--- a/src/test/scala/com/snowflake/snowpark/SNTestBase.scala
+++ b/src/test/scala/com/snowflake/snowpark/SNTestBase.scala
@@ -8,12 +8,13 @@ import com.snowflake.snowpark.internal.{ParameterUtils, ServerConnection, UDFCla
import com.snowflake.snowpark.types._
import com.snowflake.snowpark_test.TestFiles
import org.mockito.Mockito.{doReturn, spy, when}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{Assertion, Assertions, BeforeAndAfterAll}
+import org.scalatest.funsuite.{AsyncFunSuite => FunSuite}
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._
-import scala.concurrent.ExecutionContext.Implicits.global
+import scala.language.postfixOps
trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with SnowTestFiles {
@@ -100,7 +101,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S
}
}
- def checkAnswer(df1: DataFrame, df2: DataFrame, sort: Boolean): Unit = {
+ def checkAnswer(df1: DataFrame, df2: DataFrame, sort: Boolean): Assertion = {
if (sort) {
assert(
TestUtils.compare(df1.collect().sortBy(_.toString), df2.collect().sortBy(_.toString)))
@@ -109,14 +110,15 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S
}
}
- def checkAnswer(df: DataFrame, result: Row): Unit =
- checkResult(df.collect(), Seq(result), false)
+ def checkAnswer(df: DataFrame, result: Row): Assertion =
+ checkResult(df.collect(), Seq(result), sort = false)
- def checkAnswer(df: DataFrame, result: Seq[Row], sort: Boolean = true): Unit =
+ def checkAnswer(df: DataFrame, result: Seq[Row], sort: Boolean = true): Assertion =
checkResult(df.collect(), result, sort)
- def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Unit =
+ def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Assertion =
TestUtils.checkResult(result, expected, sort)
+
def checkResultIterator(result: Iterator[Row], expected: Seq[Row], sort: Boolean = true): Unit =
checkResult(result.toArray, expected, sort)
@@ -171,7 +173,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S
thunk: => T,
parameter: String,
value: String,
- skipIfParamNotExist: Boolean = false): Unit = {
+ skipIfParamNotExist: Boolean = false): Assertion = {
var parameterNotExist = false
try {
session.runQuery(s"alter session set $parameter = $value")
@@ -187,6 +189,7 @@ trait SNTestBase extends FunSuite with BeforeAndAfterAll with SFTestUtils with S
// best effort to unset the parameter.
session.runQuery(s"alter session unset $parameter")
}
+ succeed
}
def testWithTimezone[T](thunk: => T, timezone: String): T = {
diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala
index 5269f4b2..6b299064 100644
--- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala
@@ -68,7 +68,8 @@ class ServerConnectionSuite extends SNTestBase {
for (i <- 0 to 1024) {
largeData.append(Row(i))
}
- val df2 = session.createDataFrame(largeData, StructType(Seq(StructField("ID", LongType))))
+ val df2 =
+ session.createDataFrame(largeData.toSeq, StructType(Seq(StructField("ID", LongType))))
val ex2 = intercept[SnowparkClientException] {
session.conn.executeAsync(df2.snowflakePlan)
}
diff --git a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala
index e7c857a3..c4a40285 100644
--- a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala
@@ -126,7 +126,8 @@ class SnowflakePlanSuite extends SNTestBase {
for (i <- 0 to 1024) {
largeData.append(Row(i))
}
- val df2 = session.createDataFrame(largeData, StructType(Seq(StructField("ID", LongType))))
+ val df2 =
+ session.createDataFrame(largeData.toSeq, StructType(Seq(StructField("ID", LongType))))
assert(!df2.snowflakePlan.supportAsyncMode && !df2.clone.snowflakePlan.supportAsyncMode)
assert(!session.sql(" put file").snowflakePlan.supportAsyncMode)
assert(!session.sql("get file ").snowflakePlan.supportAsyncMode)
diff --git a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala
index 4864c38f..1bc2b98f 100644
--- a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala
@@ -1,6 +1,6 @@
package com.snowflake.snowpark
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import com.snowflake.snowpark.internal.SnowparkSFConnectionHandler
class SnowparkSFConnectionHandlerSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/TestUtils.scala b/src/test/scala/com/snowflake/snowpark/TestUtils.scala
index e1abedff..7fa3ea36 100644
--- a/src/test/scala/com/snowflake/snowpark/TestUtils.scala
+++ b/src/test/scala/com/snowflake/snowpark/TestUtils.scala
@@ -18,7 +18,7 @@ import com.snowflake.snowpark.internal.UDFClassPath.getPathForClass
import com.snowflake.snowpark.internal.analyzer.{quoteName, quoteNameWithoutUpperCasing}
import com.snowflake.snowpark.types._
import com.snowflake.snowpark_java.types.{InternalUtils, StructType => JavaStructType}
-import org.scalatest.{BeforeAndAfterAll, Tag}
+import org.scalatest.{Assertion, Assertions, BeforeAndAfterAll, Tag}
import java.util.{Locale, Properties}
import com.snowflake.snowpark.Session.loadConfFromFile
@@ -101,7 +101,7 @@ object TestUtils extends Logging {
s"insert into $name values ${data.map("(" + _.toString + ")").mkString(",")}")
def insertIntoTable(name: String, data: java.util.List[Object], session: Session): Unit =
- insertIntoTable(name, data.asScala.map(_.toString), session)
+ insertIntoTable(name, data.asScala.map(_.toString).toSeq, session)
def uploadFileToStage(
stageName: String,
@@ -275,17 +275,17 @@ object TestUtils extends Logging {
res
}
- def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Unit = {
+ def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Assertion = {
val sorted = if (sort) result.sortBy(_.toString) else result
val sortedExpected = if (sort) expected.sortBy(_.toString) else expected
- assert(
+ Assertions.assert(
compare(sorted, sortedExpected.toArray[Row]),
s"${sorted.map(_.toString).mkString("[", ", ", "]")} != " +
s"${sortedExpected.map(_.toString).mkString("[", ", ", "]")}")
}
- def checkResult(result: Array[Row], expected: java.util.List[Row], sort: Boolean): Unit =
- checkResult(result, expected.asScala, sort)
+ def checkResult(result: Array[Row], expected: java.util.List[Row], sort: Boolean): Assertion =
+ checkResult(result, expected.asScala.toSeq, sort)
def runQueryInSession(session: Session, sql: String): Unit =
session.runQuery(sql)
diff --git a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
index 3c3f388d..d4218287 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
@@ -48,6 +48,7 @@ class UDFClasspathSuite extends SNTestBase {
assert(mockSession.listFilesInStage(stageName1).size == jarInClassPath.size)
// Assert that no dependencies are uploaded in second time
verify(mockSession, never()).doUpload(any(), any())
+ succeed
}
test("Test that udf function's class path is automatically added") {
@@ -58,6 +59,7 @@ class UDFClasspathSuite extends SNTestBase {
udfR.registerUDF(Some(func), _toUdf((a: Int) => a + a), None)
verify(mockSession, atLeastOnce())
.addDependency(path)
+ succeed
}
test("Test that snowpark jar is NOT uploaded if stage path is available") {
@@ -85,6 +87,7 @@ class UDFClasspathSuite extends SNTestBase {
.addDependency(expectedPath)
// createJavaUDF should be only invoked once
verify(udfR, times(1)).createJavaUDF(any(), any(), any(), any(), any(), any(), any())
+ succeed
}
test(
@@ -106,6 +109,7 @@ class UDFClasspathSuite extends SNTestBase {
}
// createJavaUDF should be invoked twice as it is retried after fixing classpath
verify(udfR, times(2)).createJavaUDF(any(), any(), any(), any(), any(), any(), any())
+ succeed
}
test("Test for getPathForClass") {
diff --git a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
index 9a1eae3a..9a3b401c 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
@@ -48,6 +48,7 @@ class UDFInternalSuite extends TestData {
}
verify(mockSession, times(1)).removeDependency(path)
verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
+ succeed
}
test("Test permanent udf not failing back to upload jar", JavaStoredProcExclude) {
@@ -82,6 +83,7 @@ class UDFInternalSuite extends TestData {
}
verify(mockSession, times(1)).removeDependency(path)
verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
+ succeed
}
test("Test add version logic", JavaStoredProcExclude) {
@@ -103,6 +105,7 @@ class UDFInternalSuite extends TestData {
val path = UDFClassPath.getPathForClass(classOf[com.snowflake.snowpark.Session]).get
verify(mockSession, never()).addDependency(path)
verify(mockSession, times(1)).addPackage(Utils.clientPackageName)
+ succeed
}
test("Confirm jar files to be uploaded to expected location", JavaStoredProcExclude) {
@@ -170,6 +173,7 @@ class UDFInternalSuite extends TestData {
mockSession1.udf.registerPermanent(funcName1, udf, stageName1)
mockSession2.udf.registerPermanent(funcName2, udf, stageName1)
verify(mockSession2, never()).doUpload(any(), any())
+ succeed
} finally {
session.runQuery(s"drop function if exists $funcName1(INT)")
diff --git a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
index 1c0984b6..9a5a30ee 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
@@ -10,6 +10,7 @@ import scala.reflect.internal.util.BatchSourceFile
import scala.reflect.io.{AbstractFile, VirtualDirectory}
import scala.tools.nsc.GenericRunnerSettings
import scala.tools.nsc.interpreter.IMain
+import scala.tools.nsc.interpreter.shell.ReplReporterImpl
import scala.util.Random
@UDFTest
@@ -107,7 +108,7 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils {
val targetDir = Files.createTempDirectory(s"snowpark_test_target_")
settings.Yreploutdir.value = targetDir.toFile.getAbsolutePath
}
- val interpreter: IMain = new IMain(settings)
+ val interpreter: IMain = new IMain(settings, new ReplReporterImpl(settings))
interpreter.compileSources(new BatchSourceFile(AbstractFile.getFile(new File(fileName))))
interpreter.classLoader.loadClass(s"$packageName.$className")
@@ -123,6 +124,7 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils {
val onDiskName = s"DynamicCompile${Random.nextInt().abs}"
val onDiskClass = generateDynamicClass(packageName, onDiskName, false)
session.udf.handler.addClassToDependencies(onDiskClass)
+ succeed
}
test("ls file") {
diff --git a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala
index 2067212f..854f2adb 100644
--- a/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UtilsSuite.scala
@@ -23,6 +23,8 @@ import java.lang.{
Short => JavaShort
}
import net.snowflake.client.jdbc.SnowflakeSQLException
+import org.scalatest.Assertion
+import org.scalatest.Assertions.succeed
import scala.collection.mutable.ArrayBuffer
@@ -40,7 +42,7 @@ class UtilsSuite extends SNTestBase {
Seq("random.jar", "snow-0.3.0.jar", "snowpark", "snowpark.tar.gz").foreach(jarName => {
assert(!Utils.isSnowparkJar(jarName))
})
-
+ succeed
}
test("Logging") {
@@ -50,6 +52,7 @@ class UtilsSuite extends SNTestBase {
test("utils.version") {
// Stored Proc jdbc relies on Utils.Version. This test will prevent changes to this method.
Utils.getClass.getMethod("Version")
+ succeed
}
test("test mask secrets") {
@@ -191,10 +194,11 @@ class UtilsSuite extends SNTestBase {
double: java.lang.Double)
test("Non-nullable types") {
- TypeToSchemaConverter
- .inferSchema[(Boolean, Byte, Short, Int, Long, Float, Double)]()
- .treeString(0) ==
- """root
+ assert(
+ TypeToSchemaConverter
+ .inferSchema[(Boolean, Byte, Short, Int, Long, Float, Double)]()
+ .treeString(0) ==
+ """root
| |--_1: Boolean (nullable = false)
| |--_2: Byte (nullable = false)
| |--_3: Short (nullable = false)
@@ -202,32 +206,33 @@ class UtilsSuite extends SNTestBase {
| |--_5: Long (nullable = false)
| |--_6: Float (nullable = false)
| |--_7: Double (nullable = false)
- |""".stripMargin
+ |""".stripMargin)
}
test("Nullable types") {
- TypeToSchemaConverter
- .inferSchema[(
- Option[Int],
- JavaBoolean,
- JavaByte,
- JavaShort,
- JavaInteger,
- JavaLong,
- JavaFloat,
- JavaDouble,
- Array[Boolean],
- Map[String, Double],
- JavaBigDecimal,
- BigDecimal,
- Variant,
- Geography,
- Date,
- Time,
- Timestamp,
- Geometry)]()
- .treeString(0) ==
- """root
+ assert(
+ TypeToSchemaConverter
+ .inferSchema[(
+ Option[Int],
+ JavaBoolean,
+ JavaByte,
+ JavaShort,
+ JavaInteger,
+ JavaLong,
+ JavaFloat,
+ JavaDouble,
+ Array[Boolean],
+ Map[String, Double],
+ JavaBigDecimal,
+ BigDecimal,
+ Variant,
+ Geography,
+ Date,
+ Time,
+ Timestamp,
+ Geometry)]()
+ .treeString(0) ==
+ """root
| |--_1: Integer (nullable = true)
| |--_2: Boolean (nullable = true)
| |--_3: Byte (nullable = true)
@@ -246,7 +251,7 @@ class UtilsSuite extends SNTestBase {
| |--_16: Time (nullable = true)
| |--_17: Timestamp (nullable = true)
| |--_18: Geometry (nullable = true)
- |""".stripMargin
+ |""".stripMargin)
}
test("normalizeStageLocation") {
@@ -414,6 +419,7 @@ class UtilsSuite extends SNTestBase {
// println(s"test: $name")
Utils.validateObjectName(name)
}
+ succeed
}
test("test Utils.getUDFUploadPrefix()") {
@@ -432,6 +438,7 @@ class UtilsSuite extends SNTestBase {
// println(s"test: $name")
assert(Utils.getUDFUploadPrefix(name).matches("[\\w]+"))
}
+ succeed
}
test("negative test Utils.validateObjectName()") {
@@ -484,6 +491,7 @@ class UtilsSuite extends SNTestBase {
val ex = intercept[SnowparkClientException] { Utils.validateObjectName(name) }
assert(ex.getMessage.replaceAll("\n", "").matches(".*The object name .* is invalid."))
}
+ succeed
}
test("os name") {
@@ -506,6 +514,7 @@ class UtilsSuite extends SNTestBase {
testItems.foreach { item =>
assert(Utils.convertWindowsPathToLinux(item._1).equals(item._2))
}
+ succeed
}
test("Utils.version matches pom version") {
@@ -533,6 +542,7 @@ class UtilsSuite extends SNTestBase {
val sleep7 = Utils.retrySleepTimeInMS(7)
assert(sleep7 >= 30000 && sleep7 <= 60000)
}
+ succeed
}
test("Utils.isRetryable") {
@@ -675,7 +685,7 @@ class UtilsSuite extends SNTestBase {
}
object LoggingTester extends Logging {
- def test(): Unit = {
+ def test(): Assertion = {
// no error report
val password = "failed_error_log_password"
logInfo(s"info PASSWORD=$password")
@@ -692,5 +702,6 @@ object LoggingTester extends Logging {
logWarning(s"warning PASSWORD=$password", exception)
logError(s"error PASSWORD=$password", exception)
+ succeed
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
index 84ad53bf..dd7cd02c 100644
--- a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
@@ -4,10 +4,9 @@ import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.types._
import com.snowflake.snowpark._
import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.{BeforeAndAfterEach, Tag}
+import org.scalatest.{Assertion, BeforeAndAfterEach, Tag}
import scala.concurrent.Future
-import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Random, Success}
class AsyncJobSuite extends TestData with BeforeAndAfterEach {
@@ -352,7 +351,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach {
// This function is copied from DataFrameReader.testReadFile
def testReadFile(testName: String, testTags: Tag*)(
- thunk: (() => DataFrameReader) => Unit): Unit = {
+ thunk: (() => DataFrameReader) => Assertion): Unit = {
// test select
test(testName + " - SELECT", testTags: _*) {
thunk(() => session.read)
@@ -501,6 +500,7 @@ class AsyncJobSuite extends TestData with BeforeAndAfterEach {
df.write.mode(SaveMode.Overwrite).async.saveAsTable(list).getResult()
checkAnswer(session.table(tableName), Seq(Row(1), Row(2), Row(3)))
dropTable(tableName)
+ succeed
} finally {
dropTable(tableName)
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala
index df2e0e49..c2d9f477 100644
--- a/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/ColumnSuite.scala
@@ -240,26 +240,26 @@ class ColumnSuite extends TestData {
checkAnswer(df.select(df("one")), Row(1) :: Nil)
checkAnswer(df.select(df("oNe")), Row(1) :: Nil)
checkAnswer(df.select(df(""""ONE"""")), Row(1) :: Nil)
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df.col(""""One"""")
}
df = Seq((1)).toDF("One One")
checkAnswer(df.select(df("One One")), Row(1) :: Nil)
checkAnswer(df.select(df("\"One One\"")), Row(1) :: Nil)
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df.col(""""one one"""")
}
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df("one one")
}
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df(""""ONE ONE"""")
}
df = Seq((1)).toDF(""""One One"""")
checkAnswer(df.select(df(""""One One"""")), Row(1) :: Nil)
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df.col(""""ONE ONE"""")
}
}
@@ -333,6 +333,7 @@ class ColumnSuite extends TestData {
val colsResolved = df.schema.fields.map(_.name).map(df.apply).toSeq
val df2 = df.select(colsUnresolved ++ colsResolved)
df2.collect()
+ succeed
} finally {
session.sql(s"drop function ${temp}.${udfName}(integer)").collect()
session.sql(s"drop table ${temp}.${sName}").collect()
@@ -346,13 +347,13 @@ class ColumnSuite extends TestData {
checkAnswer(df.select(col("\"col\"\"\"")), Row(1) :: Nil)
checkAnswer(df.select(col("\"col\"")), Row(2) :: Nil)
checkAnswer(df.select(col("\"\"\"col\"")), Row(3) :: Nil)
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("\"col\"\"")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("\"\"col\"")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("\"col\"\"\"\"")).collect()
}
}
@@ -366,13 +367,13 @@ class ColumnSuite extends TestData {
checkAnswer(df.select(col("COL")), Row(1) :: Nil)
checkAnswer(df.select(col("CoL")), Row(1) :: Nil)
checkAnswer(df.select(col("\"COL\"")), Row(1) :: Nil)
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("\"Col\"")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("COL .")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(col("\"CoL\"")).collect()
}
}
@@ -386,13 +387,13 @@ class ColumnSuite extends TestData {
checkAnswer(df.select($"COL"), Row(1) :: Nil)
checkAnswer(df.select($"CoL"), Row(1) :: Nil)
checkAnswer(df.select($""""COL""""), Row(1) :: Nil)
- intercept[Exception] {
+ assertThrows[Exception] {
df.select($""""Col"""").collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select($"COL .").collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select($""""CoL"""").collect()
}
}
@@ -406,10 +407,10 @@ class ColumnSuite extends TestData {
checkAnswer(df.select("COL"), Row(1) :: Nil)
checkAnswer(df.select("CoL"), Row(1) :: Nil)
checkAnswer(df.select("\"COL\""), Row(1) :: Nil)
- intercept[Exception] {
+ assertThrows[Exception] {
df.select("\"Col\"").collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select("COL .").collect()
}
}
@@ -433,16 +434,16 @@ class ColumnSuite extends TestData {
checkAnswer(df.select(sqlExpr("\"col\" + 10")), Row(12) :: Nil)
checkAnswer(df.filter(sqlExpr("col < 1")), Nil)
checkAnswer(df.filter(sqlExpr("\"col\" = 2")).select(col("col")), Row(1) :: Nil)
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(sqlExpr("\"Col\"")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(sqlExpr("COL .")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(sqlExpr("\"CoL\"")).collect()
}
- intercept[Exception] {
+ assertThrows[Exception] {
df.select(sqlExpr("col .")).collect()
}
}
@@ -709,7 +710,7 @@ class ColumnSuite extends TestData {
new Date(timestamp - 100)))
}
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
// scala style checks dosn't support to put all of these expression in one filter()
// So split it as 2 steps.
val df2 = df.filter(
diff --git a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
index 485bf4c2..e59398a3 100644
--- a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
@@ -944,6 +944,7 @@ class CopyableDataFrameSuite extends SNTestBase {
val cloned = df.clone
assert(cloned.isInstanceOf[CopyableDataFrame])
cloned.copyInto(testTableName, Seq(col("$1").as("B")))
+ succeed
} finally {
dropTable(testTableName)
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
index 3753a480..445fb392 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
@@ -3,7 +3,6 @@ package com.snowflake.snowpark_test
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark._
import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.Matchers.the
import java.sql.ResultSet
@@ -195,6 +194,7 @@ class DataFrameAggregateSuite extends TestData {
assert(values_1_B.contains(row.getInt(1)) && values_2_B.contains(row.getInt(2)))
}
}
+ succeed
}
test("RelationalGroupedDataFrame.avg()/mean()") {
@@ -308,7 +308,7 @@ class DataFrameAggregateSuite extends TestData {
// Used temporary VIEW which is not supported by owner's mode stored proc yet
test("Window functions inside aggregate functions", JavaStoredProcExcludeOwner) {
def checkWindowError(df: => DataFrame): Unit = {
- the[SnowflakeSQLException] thrownBy {
+ assertThrows[SnowflakeSQLException] {
df.collect()
}
}
@@ -477,13 +477,13 @@ class DataFrameAggregateSuite extends TestData {
/* TODO: Add another test with eager analysis
*/
- intercept[SnowflakeSQLException] {
+ assertThrows[SnowflakeSQLException] {
courseSales.groupBy().agg(grouping($"course")).collect()
}
/*
* TODO: Add another test with eager analysis
*/
- intercept[SnowflakeSQLException] {
+ assertThrows[SnowflakeSQLException] {
courseSales.groupBy().agg(grouping_id($"course")).collect()
}
}
@@ -546,6 +546,7 @@ class DataFrameAggregateSuite extends TestData {
checkAnswer(kurtosisVal, Seq(Row(aggKurtosisResult.getDouble(1))))
}
statement.close()
+ succeed
}
test("SN - zero moments") {
@@ -699,7 +700,7 @@ class DataFrameAggregateSuite extends TestData {
checkAnswer(session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"), Nil)
- val error = intercept[SnowflakeSQLException] {
+ assertThrows[SnowflakeSQLException] {
session.sql("SELECT COUNT_IF(x) FROM tempView").collect()
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
index ce04e229..82962997 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
@@ -222,6 +222,7 @@ trait DataFrameJoinSuite extends SNTestBase {
assert(ex2.getMessage.contains("INTCOL"))
assert(ex2.getMessage().contains("ambiguous"))
}
+ succeed
}
test("join -- expressions on ambiguous columns") {
@@ -304,7 +305,7 @@ trait DataFrameJoinSuite extends SNTestBase {
lhs.join(rhs, Seq("intcol"), joinType).select(lhs("negcol"), rhs("negcol")),
Row(-1, -10) :: Row(-2, -20) :: Nil)
}
-
+ succeed
}
test("Columns with and without quotes") {
@@ -400,6 +401,7 @@ trait DataFrameJoinSuite extends SNTestBase {
assert(ex.message.contains(msg))
}
}
+ succeed
}
test("clone can help these self joins") {
@@ -616,6 +618,7 @@ trait DataFrameJoinSuite extends SNTestBase {
val dfOne = df.select(lit(1).as("a"))
val dfTwo = session.range(10).select(lit(2).as("b"))
dfOne.join(dfTwo, $"a" === $"b", "left").collect()
+ succeed
}
test("name alias in multiple join") {
@@ -645,6 +648,7 @@ trait DataFrameJoinSuite extends SNTestBase {
df_end_stations("station_name"),
df_trips("starttime"))
.collect()
+ succeed
} finally {
dropTable(tableTrips)
@@ -679,6 +683,7 @@ trait DataFrameJoinSuite extends SNTestBase {
df_end_stations("station%name"),
df_trips("starttime"))
.collect()
+ succeed
} finally {
dropTable(tableTrips)
@@ -806,6 +811,7 @@ trait DataFrameJoinSuite extends SNTestBase {
|-------------------
|""".stripMargin)
df.select(dfRight("*"), dfRight("c")).show()
+ succeed
}
test("select columns on join result with conflict name", JavaStoredProcExclude) {
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
index 2146cdce..28b84d8b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
@@ -105,6 +105,7 @@ class DataFrameRangeSuite extends SNTestBase {
assert(res.head.getLong(1) == expSum)
}
}
+ succeed
}
test("Session range with Max and Min") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
index 82cd6e7f..89b87c87 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
@@ -4,7 +4,7 @@ import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.types._
import com.snowflake.snowpark._
import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.Tag
+import org.scalatest.{Assertion, Tag}
import scala.util.Random
@@ -226,6 +226,7 @@ class DataFrameReaderSuite extends SNTestBase {
.csv(path),
result)
})
+ succeed
} finally {
runQuery(s"drop file format $formatName", session)
}
@@ -338,6 +339,7 @@ class DataFrameReaderSuite extends SNTestBase {
.parquet(s"@$tmpStageName/$ctype/")
.collect()
})
+ succeed
})
testReadFile("read parquet with no schema")(reader => {
@@ -502,7 +504,7 @@ class DataFrameReaderSuite extends SNTestBase {
})
def testReadFile(testName: String, testTags: Tag*)(
- thunk: (() => DataFrameReader) => Unit): Unit = {
+ thunk: (() => DataFrameReader) => Assertion): Unit = {
// test select
test(testName + " - SELECT", testTags: _*) {
thunk(() => session.read)
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala
index 39629dec..e4a76c93 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSetOperationsSuite.scala
@@ -4,6 +4,7 @@ import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.types.IntegerType
import com.snowflake.snowpark.{Column, Row, SnowparkClientException, TestData}
import net.snowflake.client.jdbc.SnowflakeSQLException
+import org.scalatest.Assertion
import java.sql.{Date, Timestamp}
@@ -11,7 +12,7 @@ class DataFrameSetOperationsSuite extends TestData {
import session.implicits._
test("Union with filters") {
- def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
+ def check(newCol: Column, filter: Column, result: Seq[Row]): Assertion = {
val df1 = session.createDataFrame(Seq((1, 1))).toDF("a", "b").withColumn("c", newCol)
val df2 = df1.union(df1).withColumn("d", lit(100)).filter(filter)
@@ -27,7 +28,7 @@ class DataFrameSetOperationsSuite extends TestData {
}
test("Union All with filters") {
- def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = {
+ def check(newCol: Column, filter: Column, result: Seq[Row]): Assertion = {
val df1 = session.createDataFrame(Seq((1, 1))).toDF("a", "b").withColumn("c", newCol)
val df2 = df1.unionAll(df1).withColumn("d", lit(100)).filter(filter)
@@ -161,7 +162,7 @@ class DataFrameSetOperationsSuite extends TestData {
df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
df2 = Seq((4, 5, 6)).toDF("a", "c", "d")
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionByName(df2)
}
}
@@ -182,7 +183,7 @@ class DataFrameSetOperationsSuite extends TestData {
df1 = Seq((1, 2, 3)).toDF("a", "b", "c")
df2 = Seq((4, 5, 6)).toDF("a", "c", "d")
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionAllByName(df2)
}
}
@@ -196,7 +197,7 @@ class DataFrameSetOperationsSuite extends TestData {
df1 = Seq((1, 2, 3)).toDF(""""a"""", "b", "c")
df2 = Seq((4, 5, 6)).toDF("a", "c", "b")
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionByName(df2)
}
}
@@ -210,7 +211,7 @@ class DataFrameSetOperationsSuite extends TestData {
df1 = Seq((1, 2, 3)).toDF(""""a"""", "b", "c")
df2 = Seq((4, 5, 6)).toDF("a", "c", "b")
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionAllByName(df2)
}
}
@@ -255,6 +256,7 @@ class DataFrameSetOperationsSuite extends TestData {
dates.union(widenTypedRows).collect()
dates.except(widenTypedRows).collect()
dates.intersect(widenTypedRows).collect()
+ succeed
}
/*
@@ -271,7 +273,7 @@ class DataFrameSetOperationsSuite extends TestData {
}
df1 = Seq((1, 1)).toDF("c0", "c1")
df2 = Seq((1, 1)).toDF(c0, c1)
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionByName(df2)
}
}
@@ -286,7 +288,7 @@ class DataFrameSetOperationsSuite extends TestData {
}
df1 = Seq((1, 1)).toDF("c0", "c1")
df2 = Seq((1, 1)).toDF(c0, c1)
- intercept[SnowparkClientException] {
+ assertThrows[SnowparkClientException] {
df1.unionAllByName(df2)
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala
index 45aec853..67ce38f5 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala
@@ -5,7 +5,8 @@ import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.internal.analyzer._
import com.snowflake.snowpark.types._
import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.BeforeAndAfterEach
+import org.scalatest.{Assertion, BeforeAndAfterEach}
+
import java.sql.{Date, Time, Timestamp}
import scala.util.Random
@@ -159,7 +160,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
skipIfParamNotExist = true)
}
- private def testCacheResult(): Unit = {
+ private def testCacheResult(): Assertion = {
val tableName = randomName()
runQuery(s"create table $tableName (num int)", session)
runQuery(s"insert into $tableName values(1),(2)", session)
@@ -446,8 +447,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
test("df.stat.approxQuantile", JavaStoredProcExclude) {
assert(approxNumbers.stat.approxQuantile("a", Array(0.5))(0).get == 4.5)
assert(
- approxNumbers.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)).deep ==
- Array(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0)).deep)
+ approxNumbers.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)).toSeq ==
+ Seq(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0)))
// Probability out of range error and apply on string column error.
assertThrows[SnowflakeSQLException](approxNumbers.stat.approxQuantile("a", Array(-1d)))
@@ -457,8 +458,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
assert(session.table(tableName).stat.approxQuantile("num", Array(0.5))(0).isEmpty)
val res = double2.stat.approxQuantile(Array("a", "b"), Array(0, 0.1, 0.6))
- assert(res(0).deep == Array(Some(0.1), Some(0.12000000000000001), Some(0.22)).deep)
- assert(res(1).deep == Array(Some(0.5), Some(0.52), Some(0.62)).deep)
+ assert(res(0).toSeq == Seq(Some(0.1), Some(0.12000000000000001), Some(0.22)))
+ assert(res(1).toSeq == Seq(Some(0.5), Some(0.52), Some(0.62)))
// ApproxNumbers2 contains a column called T, which conflicts with tmpColumnName.
// This test demos that the query still works.
@@ -630,12 +631,12 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
assert(nullData1.first().get == Row(null))
assert(integer1.filter(col("a") < 0).first().isEmpty)
- integer1.first(2) sameElements Seq(Row(1), Row(2))
+ assert(integer1.first(2) sameElements Seq(Row(1), Row(2)))
// return all elements
- integer1.first(3) sameElements Seq(Row(1), Row(2), Row(3))
- integer1.first(10) sameElements Seq(Row(1), Row(2), Row(3))
- integer1.first(-10) sameElements Seq(Row(1), Row(2), Row(3))
+ assert(integer1.first(3) sameElements Seq(Row(1), Row(2), Row(3)))
+ assert(integer1.first(10) sameElements Seq(Row(1), Row(2), Row(3)))
+ assert(integer1.first(-10) sameElements Seq(Row(1), Row(2), Row(3)))
}
test("sample() with row count") {
@@ -726,7 +727,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
weights: Array[Double],
index: Int,
count: Long,
- TotalCount: Long): Unit = {
+ TotalCount: Long): Assertion = {
val expectedRowCount = TotalCount * weights(index) / weights.sum
assert(Math.abs(expectedRowCount - count) < expectedRowCount * samplingDeviation)
}
@@ -1514,6 +1515,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
test("createDataFrame from empty Seq with schema inference") {
Seq.empty[(Int, Int)].toDF("a", "b").schema.printTreeString()
+ succeed
}
test("schema inference binary type") {
@@ -2182,11 +2184,11 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach {
val row1 = df1.collect().head
// result is non-deterministic.
(row1.getInt(0), row1.getInt(1), row1.getInt(2), row1.getInt(3)) match {
- case (1, 1, 1, 1) =>
- case (1, 1, 1, 2) =>
- case (1, 1, 2, 3) =>
- case (1, 2, 3, 4) =>
- case _ => throw new Exception("wrong result")
+ case (1, 1, 1, 1) => succeed
+ case (1, 1, 1, 2) => succeed
+ case (1, 1, 2, 3) => succeed
+ case (1, 2, 3, 4) => succeed
+ case _ => fail("wrong result")
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala
index 0d19d5e8..2e1bcc14 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala
@@ -42,6 +42,7 @@ class DataTypeSuite extends SNTestBase {
}
Seq(ByteType, ShortType, IntegerType, LongType).foreach(verifyIntegralType)
+ succeed
}
test("FractionalType") {
@@ -55,6 +56,7 @@ class DataTypeSuite extends SNTestBase {
}
Seq(FloatType, DoubleType).foreach(verifyIntegralType)
+ succeed
}
test("DecimalType") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala
index aac95147..9dba1fe8 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala
@@ -162,6 +162,7 @@ class FileOperationSuite extends SNTestBase {
assert(secondResult.length == 3)
// On GCP, the files are not skipped if target file already exists
secondResult.map(row => assert(row.status.equals("SKIPPED") || row.status.equals("UPLOADED")))
+ succeed
}
test("put() negative test") {
@@ -453,6 +454,7 @@ class FileOperationSuite extends SNTestBase {
fileName = s"streamFile_${TestUtils.randomString(5)}.csv"
testStreamRoundTrip(s"$schema.$tempStage/$fileName", s"$schema.$tempStage/$fileName.gz", true)
+ succeed
}
@@ -469,6 +471,7 @@ class FileOperationSuite extends SNTestBase {
s"$randomNewSchema.$tempStage/$fileName",
s"$randomNewSchema.$tempStage/$fileName.gz",
true)
+ succeed
} finally {
session.sql(s"DROP SCHEMA $randomNewSchema").collect()
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
index e473de12..cc07015b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala
@@ -235,6 +235,7 @@ trait FunctionSuite extends TestData {
val df = session.sql("select 1")
df.select(random(123)).collect()
df.select(random()).collect()
+ succeed
}
test("sqrt") {
@@ -1907,6 +1908,7 @@ trait FunctionSuite extends TestData {
approx_percentile_estimate(approx_percentile_accumulate(col("a")), 0.5)),
approxNumbers.select(approx_percentile(col("a"), 0.5)),
sort = false)
+ succeed
}
test("approx_percentile_combine") {
@@ -2136,6 +2138,7 @@ trait FunctionSuite extends TestData {
assert(row.getInt(0) >= 1)
assert(row.getInt(0) <= 5)
})
+ succeed
}
test("listagg") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala
index 7ffe6950..159d5e93 100644
--- a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala
@@ -1,6 +1,6 @@
package com.snowflake.snowpark_test
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import org.scalatest.exceptions.TestFailedException
import scala.language.postfixOps
diff --git a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala
index 7bec5be4..796935e5 100644
--- a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala
@@ -1,6 +1,6 @@
package com.snowflake.snowpark_test
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import com.snowflake.snowpark.internal.JavaUtils._
import com.snowflake.snowpark.types.Variant
import com.snowflake.snowpark_java.types.{Variant => JavaVariant}
diff --git a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala
index a2f70f98..e91f7a9d 100644
--- a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala
@@ -69,6 +69,7 @@ class LargeDataFrameSuite extends TestData {
(0 until (result.length - 1)).foreach(index =>
assert(result(index).getInt(0) < result(index + 1).getInt(0)))
+ succeed
}
test("createDataFrame for large values: basic types") {
@@ -128,12 +129,13 @@ class LargeDataFrameSuite extends TestData {
largeData.append(
Row(1025, null, null, null, null, null, null, null, null, null, null, null, null))
- val result = session.createDataFrame(largeData, schema)
+ val largeDataSeq = largeData.toSeq
+ val result = session.createDataFrame(largeDataSeq, schema)
// byte, short, int, long are converted to long
// float and double are converted to double
result.schema.printTreeString()
assert(getSchemaString(result.schema) == schemaString)
- checkAnswer(result.sort(col("id")), largeData, false)
+ checkAnswer(result.sort(col("id")), largeDataSeq, sort = false)
}
test("createDataFrame for large values: time") {
@@ -146,7 +148,7 @@ class LargeDataFrameSuite extends TestData {
}
largeData.append(Row(rowCount, null))
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
assert(
getSchemaString(df.schema) ==
"""root
@@ -160,7 +162,7 @@ class LargeDataFrameSuite extends TestData {
expected.append(Row(i.toLong, snowflakeTime))
}
expected.append(Row(rowCount, null))
- checkAnswer(df.sort(col("id")), expected, sort = false)
+ checkAnswer(df.sort(col("id")), expected.toSeq, sort = false)
}
// In the result, Array, Map and Geography are String data
@@ -188,7 +190,7 @@ class LargeDataFrameSuite extends TestData {
}
largeData.append(Row(rowCount, null, null, null, null, null))
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
assert(
getSchemaString(df.schema) ==
"""root
@@ -224,7 +226,7 @@ class LargeDataFrameSuite extends TestData {
|}""".stripMargin)))
}
expected.append(Row(rowCount, null, null, null, null, null))
- checkAnswer(df.sort(col("id")), expected, sort = false)
+ checkAnswer(df.sort(col("id")), expected.toSeq, sort = false)
}
test("createDataFrame for large values: variant in array and map") {
@@ -240,13 +242,13 @@ class LargeDataFrameSuite extends TestData {
Row(i.toLong, Array(new Variant(1), new Variant("\"'")), Map("a" -> new Variant("\"'"))))
}
largeData.append(Row(rowCount, null, null))
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
val expected = new ArrayBuffer[Row]()
for (i <- 0 until rowCount) {
expected.append(Row(i.toLong, "[\n 1,\n \"\\\"'\"\n]", "{\n \"a\": \"\\\"'\"\n}"))
}
expected.append(Row(rowCount, null, null))
- checkAnswer(df.sort(col("id")), expected, sort = false)
+ checkAnswer(df.sort(col("id")), expected.toSeq, sort = false)
}
test("createDataFrame for large values: geography in array and map") {
@@ -269,7 +271,7 @@ class LargeDataFrameSuite extends TestData {
"b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[300,100]}"))))
}
largeData.append(Row(rowCount, null, null))
- val df = session.createDataFrame(largeData, schema)
+ val df = session.createDataFrame(largeData.toSeq, schema)
val expected = new ArrayBuffer[Row]()
for (i <- 0 until rowCount) {
expected.append(
@@ -281,7 +283,7 @@ class LargeDataFrameSuite extends TestData {
" 300,\n 100\n ],\n \"type\": \"Point\"\n }\n}"))
}
expected.append(Row(rowCount, null, null))
- checkAnswer(df.sort(col("id")), expected, sort = false)
+ checkAnswer(df.sort(col("id")), expected.toSeq, sort = false)
}
test("test large ResultSet with multiple chunks") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala
index dcb35cdb..edb02a85 100644
--- a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala
@@ -111,6 +111,7 @@ class PermanentUDFSuite extends TestData {
checkAnswer(df.select(callUDF(permFuncName, df("a"))), Seq(Row(2), Row(3)))
runQuery(s"drop function $tempFuncName(INT)", session)
runQuery(s"drop function $permFuncName(INT)", session)
+ succeed
} finally {
runQuery(s"drop function if exists $tempFuncName(INT)", session)
runQuery(s"drop function if exists $permFuncName(INT)", session)
@@ -153,6 +154,7 @@ class PermanentUDFSuite extends TestData {
}
assert(ex.getMessage.matches(".*The object name .* is invalid."))
}
+ succeed
}
test("Clean up uploaded jar files if UDF registration fails", JavaStoredProcExclude) {
diff --git a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala
index 93dbc99c..5a85c412 100644
--- a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala
@@ -58,6 +58,7 @@ class ResultSchemaSuite extends TestData {
verifySchema(
"alter session set ABORT_DETACHED_QUERY=false",
session.sql("alter session set ABORT_DETACHED_QUERY=false").schema)
+ succeed
}
test("list, remove file") {
@@ -75,10 +76,12 @@ class ResultSchemaSuite extends TestData {
verifySchema(
s"remove @$stageName/$testFileCsv",
session.sql(s"remove @$stageName/$testFile2Csv").schema)
+ succeed
}
test("select") {
verifySchema(s"select * from $tableName", session.sql(s"select * from $tableName").schema)
+ succeed
}
test("analyze schema") {
@@ -90,6 +93,7 @@ class ResultSchemaSuite extends TestData {
verifySchema(
s"""select string, "int", array, "date" from $fullTypesTable where \"int\" > 0""",
df2.schema)
+ succeed
}
// ignore it for now since we are modifying the analyzer system.
@@ -146,6 +150,7 @@ class ResultSchemaSuite extends TestData {
assert(tsSchema(index).dataType == typeMap(index).tsType)
})
statement.close()
+ succeed
}
test("verify Geography schema type") {
@@ -173,6 +178,7 @@ class ResultSchemaSuite extends TestData {
assert(resultMeta.getColumnType(1) == Types.BINARY)
assert(tsSchema.head.dataType == GeographyType)
statement.close()
+ succeed
} finally {
// Assign output format to the default value
runQuery(s"alter session set GEOGRAPHY_OUTPUT_FORMAT='GeoJSON'", session)
@@ -204,6 +210,7 @@ class ResultSchemaSuite extends TestData {
assert(resultMeta.getColumnType(1) == Types.BINARY)
assert(tsSchema.head.dataType == GeometryType)
statement.close()
+ succeed
} finally {
// Assign output format to the default value
runQuery(s"alter session set GEOMETRY_OUTPUT_FORMAT='GeoJSON'", session)
@@ -217,5 +224,6 @@ class ResultSchemaSuite extends TestData {
assert(resultMeta.getColumnType(1) == Types.TIME)
assert(tsSchema.head.dataType == TimeType)
statement.close()
+ succeed
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala
index 35e8c572..757fd728 100644
--- a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala
@@ -1,7 +1,7 @@
package com.snowflake.snowpark_test
import com.snowflake.snowpark.types.{Geography, Variant}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
import java.io.UncheckedIOException
import java.sql.{Date, Time, Timestamp}
@@ -64,7 +64,7 @@ class ScalaVariantSuite extends FunSuite {
assert(vFloat.asLong() == 1L)
assert(vFloat.asInt() == 1)
assert(vFloat.asShort() == 1.toShort)
- assert((vFloat.asBigDecimal() - BigDecimal("1.1")).abs.doubleValue() < 0.0000001)
+ assert((vFloat.asBigDecimal() - BigDecimal("1.1")).abs.doubleValue < 0.0000001)
assert(vFloat.asBigInt() == BigInt("1"))
assert(vFloat.asTimestamp().equals(new Timestamp(1L)))
assert(vFloat.asString().equals("1.1"))
diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
index 9b82ba7b..9d8a990b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala
@@ -47,6 +47,7 @@ class SessionSuite extends SNTestBase {
t1.run()
t2.run()
+ succeed
}
test("Test for get or create session") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala
index a0ae3be0..75442807 100644
--- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala
@@ -76,7 +76,9 @@ trait SqlSuite extends SNTestBase {
// add spaces to the query
val putQuery =
- s" put ${TestUtils.escapePath(path.toString).replace("file:/", "file:///")} @$stageName "
+ s" put ${TestUtils
+ .escapePath(path.toString)
+ .replace("file:/", "file:" + "/" + "/" + "/")} @$stageName "
val put = session.sql(putQuery)
put.schema.printTreeString()
// should upload nothing
@@ -99,6 +101,7 @@ trait SqlSuite extends SNTestBase {
// TODO: Below assertion failed on GCP because JDBC bug SNOW-493080
// Disable this check temporally
// assert(new File(s"$outputPath/$fileName.gz").exists())
+ succeed
} finally {
// remove tmp file
new Directory(new File(outputPath)).deleteRecursively()
diff --git a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala
index 8f214271..7cf8881e 100644
--- a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala
@@ -145,6 +145,7 @@ class StoredProcedureSuite extends SNTestBase {
test("decimal input") {
val sp = session.sproc.registerTemporary((_: Session, num: java.math.BigDecimal) => num)
session.storedProcedure(sp, java.math.BigDecimal.valueOf(123)).show()
+ succeed
}
test("binary type") {
@@ -2286,6 +2287,7 @@ println(s"""
assert(newSession.sql(s"show procedures like '$name1'").collect().isEmpty)
assert(newSession.sql(s"show procedures like '$name2'").collect().isEmpty)
newSession.close()
+ succeed
}
// temporary disabled, waiting for server side JDBC upgrade
diff --git a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala
index 2d5357fe..0cc490a9 100644
--- a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala
@@ -181,6 +181,7 @@ class TableSuite extends TestData {
checkAnswer(session.table(name2), Seq(Row(1), Row(2), Row(3)))
dropTable(name2)
+ succeed
}
test("read from different schema") {
@@ -308,6 +309,7 @@ class TableSuite extends TestData {
df.write.mode(SaveMode.Overwrite).saveAsTable(list)
checkAnswer(session.table(tableName), Seq(Row(1), Row(2), Row(3)))
dropTable(tableName)
+ succeed
} finally {
dropTable(tableName)
diff --git a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala
index 61ed76a0..a106e567 100644
--- a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala
@@ -89,15 +89,15 @@ trait UDFSuite extends TestData {
}
assert(ex.getMessage.contains("mutable.Map[String,String]"))
- intercept[UnsupportedOperationException] {
+ assertThrows[UnsupportedOperationException] {
udf((x: java.lang.Integer) => x + x)
}
- intercept[UnsupportedOperationException] {
+ assertThrows[UnsupportedOperationException] {
udf((x: Option[String]) => x.map(a => s"$a"))
}
- intercept[UnsupportedOperationException] {
+ assertThrows[UnsupportedOperationException] {
udf((x: mutable.Map[String, String]) => x.keys)
}
}
@@ -143,6 +143,7 @@ trait UDFSuite extends TestData {
assert(result.size == 2)
Seq(1, 2, 3).foreach(i => assert(result(0).getString(0).contains(s"convertToMap$i")))
Seq(4, 5, 6).foreach(i => assert(result(1).getString(0).contains(s"convertToMap$i")))
+ succeed
}
test("UDF with multiple args of type map, array etc") {
@@ -2816,6 +2817,7 @@ class AlwaysCleanUDFSuite extends UDFSuite with AlwaysCleanSession {
val myDf = session.sql("select 'Raymond' NAME")
val readFileUdf = udf(TestClassWithoutFieldAccess.run)
myDf.withColumn("CONCAT", readFileUdf(col("NAME"))).show()
+ succeed
}
}
@@ -2824,7 +2826,7 @@ class NeverCleanUDFSuite extends UDFSuite with NeverCleanSession {
test("Test without closure cleaner") {
val myDf = session.sql("select 'Raymond' NAME")
// Without closure cleaner, this test will throw error
- intercept[NotSerializableException] {
+ assertThrows[NotSerializableException] {
val readFileUdf = udf(TestClassWithoutFieldAccess.run)
myDf.withColumn("CONCAT", readFileUdf(col("NAME"))).show()
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
index d6845bde..fec0da6d 100644
--- a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala
@@ -12,7 +12,7 @@ import com.snowflake.snowpark.internal._
import com.snowflake.snowpark.types._
import com.snowflake.snowpark.udtf._
-import scala.collection.{Seq, mutable}
+import scala.collection.mutable
@UDFTest
class UDTFSuite extends TestData {
@@ -292,6 +292,7 @@ class UDTFSuite extends TestData {
session.udtf.registerTemporary(funcName, new MyDuplicateRegisterUDTF1())
session.udtf.registerTemporary(funcName2, new MyDuplicateRegisterUDTF1())
+ succeed
} finally {
runQuery(s"drop function if exists $funcName(STRING)", session)
}
@@ -421,6 +422,7 @@ class UDTFSuite extends TestData {
|""".stripMargin)
checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false)
}
+ succeed
} finally {
runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session)
}
@@ -2138,6 +2140,7 @@ class UDTFSuite extends TestData {
Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)")))
df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq(df("b"))).show()
df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq.empty).show()
+ succeed
}
test("single partition") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala
index 79acd450..91d43641 100644
--- a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala
@@ -108,6 +108,7 @@ class ViewSuite extends TestData {
checkAnswer(session.table(viewName), Seq(Row(1), Row(2), Row(3)))
dropView(viewName)
+ succeed
} finally {
dropView(viewName)
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala
index 89e8cbd4..f5e4dce3 100644
--- a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala
@@ -114,14 +114,14 @@ class WindowFramesSuite extends TestData {
.over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))),
Seq(Row(1, 1)))
- intercept[SnowflakeSQLException](
- df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect)
+ assertThrows[SnowflakeSQLException](
+ df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect())
- intercept[SnowflakeSQLException](
- df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect)
+ assertThrows[SnowflakeSQLException](
+ df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect())
- intercept[SnowflakeSQLException](
- df.select(min($"key").over(window.rangeBetween(-1, 1))).collect)
+ assertThrows[SnowflakeSQLException](
+ df.select(min($"key").over(window.rangeBetween(-1, 1))).collect())
}
test("SN - range between should accept numeric values only when bounded") {
@@ -136,13 +136,13 @@ class WindowFramesSuite extends TestData {
Row("non_numeric", "non_numeric") :: Nil)
// TODO: Add another test with eager mode enabled
- intercept[SnowflakeSQLException](
+ assertThrows[SnowflakeSQLException](
df.select(min($"value").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect())
- intercept[SnowflakeSQLException](
+ assertThrows[SnowflakeSQLException](
df.select(min($"value").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect())
- intercept[SnowflakeSQLException](
+ assertThrows[SnowflakeSQLException](
df.select(min($"value").over(window.rangeBetween(-1, 1))).collect())
}
@@ -186,5 +186,6 @@ class WindowFramesSuite extends TestData {
assert(values_1_B.contains(row.getInt(1)) && values_2_B.contains(row.getInt(2)))
}
}
+ succeed
}
}
diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala
index 00100566..7506c34f 100644
--- a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala
@@ -3,7 +3,7 @@ package com.snowflake.snowpark_test
import com.snowflake.snowpark.functions._
import com.snowflake.snowpark.{DataFrame, Row, TestData, Window}
import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.Matchers.the
+import org.scalatest.matchers.should.Matchers.the
import scala.reflect.ClassTag
@@ -97,6 +97,7 @@ class WindowSpecSuite extends TestData {
|GROUP BY a
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1
|""".stripMargin))
+ succeed
}
test("reuse window partitionBy") {
@@ -171,7 +172,7 @@ class WindowSpecSuite extends TestData {
test("SN - window function should fail if order by clause is not specified") {
val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value")
- val e = intercept[SnowflakeSQLException](
+ assertThrows[SnowflakeSQLException](
// Here we missed .orderBy("key")!
df.select(row_number().over(Window.partitionBy($"value"))).collect())
}
@@ -329,8 +330,7 @@ class WindowSpecSuite extends TestData {
test("SN - aggregation function on invalid column") {
val df = Seq((1, "1")).toDF("key", "value")
- val e =
- intercept[SnowflakeSQLException](df.select($"key", count($"invalid").over()).collect())
+ assertThrows[SnowflakeSQLException](df.select($"key", count($"invalid").over()).collect())
}
test("SN - statistical functions") {
diff --git a/src/test/scala/com/snowflake/test/QueryTagSuite.scala b/src/test/scala/com/snowflake/test/QueryTagSuite.scala
index 8ad7f44c..760f7fa9 100644
--- a/src/test/scala/com/snowflake/test/QueryTagSuite.scala
+++ b/src/test/scala/com/snowflake/test/QueryTagSuite.scala
@@ -21,5 +21,6 @@ class QueryTagSuite extends SNTestBase {
res.next()
assert(res.getString("value").equals(someTag))
statement.close()
+ succeed
}
}