diff --git a/pom.xml b/pom.xml
index 7ba2f364..3390bf72 100644
--- a/pom.xml
+++ b/pom.xml
@@ -34,9 +34,9 @@
 1.8
 1.8
 UTF-8
-2.12.18
-2.12
-4.2.0
+2.13.13
+2.13
+4.20.5
 3.16.0
 ${scala.compat.version}
 Snowpark ${project.version}
@@ -134,7 +134,7 @@
 org.scalatest
 scalatest_${scala.compat.version}
-3.0.5
+3.2.18
 test
@@ -496,7 +496,7 @@
 test-coverage
-2.12.15
+2.13.13
diff --git a/src/main/scala/com/snowflake/snowpark/DataFrame.scala b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
index 91f0021b..107a06e4 100644
--- a/src/main/scala/com/snowflake/snowpark/DataFrame.scala
+++ b/src/main/scala/com/snowflake/snowpark/DataFrame.scala
@@ -1797,7 +1797,10 @@ class DataFrame private[snowpark] (
    * @param firstArg The first argument to pass to the specified table function.
    * @param remaining A list of any additional arguments for the specified table function.
    */
-  def join(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+  def join(
+      func: com.snowflake.snowpark.TableFunction,
+      firstArg: Column,
+      remaining: Column*): DataFrame =
     join(func, firstArg +: remaining)
   /**
@@ -1824,7 +1827,7 @@
    * object or an object that you create from the [[TableFunction]] class.
    * @param args A list of arguments to pass to the specified table function.
    */
-  def join(func: TableFunction, args: Seq[Column]): DataFrame =
+  def join(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame =
     joinTableFunction(func.call(args: _*), None)
   /**
@@ -1854,7 +1857,7 @@
    * @param orderBy A list of columns ordered by.
    */
   def join(
-      func: TableFunction,
+      func: com.snowflake.snowpark.TableFunction,
       args: Seq[Column],
       partitionBy: Seq[Column],
       orderBy: Seq[Column]): DataFrame =
@@ -1892,7 +1895,7 @@
    * Some functions, like `flatten`, have named parameters.
    * Use this map to specify the parameter names and their corresponding values.
    */
-  def join(func: TableFunction, args: Map[String, Column]): DataFrame =
+  def join(func: com.snowflake.snowpark.TableFunction, args: Map[String, Column]): DataFrame =
     joinTableFunction(func.call(args), None)
   /**
@@ -1929,7 +1932,7 @@
    * @param orderBy A list of columns ordered by.
    */
   def join(
-      func: TableFunction,
+      func: com.snowflake.snowpark.TableFunction,
       args: Map[String, Column],
       partitionBy: Seq[Column],
       orderBy: Seq[Column]): DataFrame =
@@ -2366,7 +2369,7 @@
       }
     }
     lines.append(value.substring(startIndex))
-    lines
+    lines.toSeq
   }
   def convertValueToString(value: Any): String =
@@ -2501,7 +2504,7 @@
    * and view name.
    */
  def createOrReplaceView(multipartIdentifier: java.util.List[String]): Unit =
-    createOrReplaceView(multipartIdentifier.asScala)
+    createOrReplaceView(multipartIdentifier.asScala.toSeq)
  /**
   * Creates a temporary view that returns the same results as this DataFrame.
   *
@@ -2564,7 +2567,7 @@
   * view name.
   */
  def createOrReplaceTempView(multipartIdentifier: java.util.List[String]): Unit =
-    createOrReplaceTempView(multipartIdentifier.asScala)
+    createOrReplaceTempView(multipartIdentifier.asScala.toSeq)
  private def doCreateOrReplaceView(viewName: String, viewType: ViewType): Unit = {
    session.conn.telemetry.reportActionCreateOrReplaceView()
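The `.asScala.toSeq` pattern above is the heart of the 2.13 migration: on Scala 2.13, `scala.Seq` is an alias for `scala.collection.immutable.Seq`, while `java.util.List#asScala` still yields a `mutable.Buffer`, so the converted collection must be copied before it can be passed where a `Seq` is expected. A minimal sketch of the idiom (the `describe` helper is illustrative, not Snowpark API):

```scala
import scala.jdk.CollectionConverters._ // 2.13 home of asScala/asJava

object SeqMigration {
  // `Seq` here is scala.collection.immutable.Seq on 2.13.
  def describe(parts: Seq[String]): String = parts.mkString(", ")

  def fromJava(parts: java.util.List[String]): String =
    // asScala yields a mutable.Buffer view; .toSeq copies to an immutable Seq.
    describe(parts.asScala.toSeq)
}
```

On 2.12 the bare `asScala` compiled because `scala.Seq` was the collection-agnostic `collection.Seq`, to which `Buffer` already conforms.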
diff --git a/src/main/scala/com/snowflake/snowpark/Row.scala b/src/main/scala/com/snowflake/snowpark/Row.scala
index a1dc5aef..8617eb7b 100644
--- a/src/main/scala/com/snowflake/snowpark/Row.scala
+++ b/src/main/scala/com/snowflake/snowpark/Row.scala
@@ -350,7 +350,7 @@ class Row protected (values: Array[Any]) extends Serializable {
    * @since 1.13.0
    * @group getter
    */
-  def getSeq[T](index: Int): Seq[T] = {
+  def getSeq[T: ClassTag](index: Int): Seq[T] = {
     val result = getAs[Array[_]](index)
     result.map {
       case x: T => x
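The new `ClassTag` context bound makes the `case x: T` pattern checkable at runtime; without it the type test is erased and the match would accept any element. A standalone sketch of the same idiom (the helper name is illustrative):

```scala
import scala.reflect.ClassTag

object TypedAccess {
  // With the ClassTag bound, the compiler turns `case x: T` into a real
  // runtime class test instead of an unchecked (always-true) pattern.
  def firstOfType[T: ClassTag](values: Array[Any]): Option[T] =
    values.collectFirst { case x: T => x }
}

// TypedAccess.firstOfType[String](Array(1, "two", 3.0)) == Some("two")
```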
diff --git a/src/main/scala/com/snowflake/snowpark/Session.scala b/src/main/scala/com/snowflake/snowpark/Session.scala
index 633c8e42..b590b9a5 100644
--- a/src/main/scala/com/snowflake/snowpark/Session.scala
+++ b/src/main/scala/com/snowflake/snowpark/Session.scala
@@ -444,7 +444,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
    * @since 0.2.0
    */
  def table(multipartIdentifier: java.util.List[String]): Updatable =
-    table(multipartIdentifier.asScala)
+    table(multipartIdentifier.asScala.toSeq)
  /**
   * Returns an Updatable that points to the specified table.
@@ -497,7 +497,10 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
    * @param firstArg the first function argument of the given table function.
    * @param remaining all remaining function arguments.
    */
-  def tableFunction(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+  def tableFunction(
+      func: com.snowflake.snowpark.TableFunction,
+      firstArg: Column,
+      remaining: Column*): DataFrame =
     tableFunction(func, firstArg +: remaining)
  /**
@@ -525,7 +528,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
    * referred from the built-in list from tableFunctions.
    * @param args function arguments of the given table function.
    */
-  def tableFunction(func: TableFunction, args: Seq[Column]): DataFrame = {
+  def tableFunction(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame = {
     // Use df.join to apply function result if args contains a DF column
     val sourceDFs = args.flatMap(_.expr.sourceDFs)
     if (sourceDFs.isEmpty) {
@@ -569,7 +572,9 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
    * Some functions, like flatten, have named parameters.
    * use this map to assign values to the corresponding parameters.
    */
-  def tableFunction(func: TableFunction, args: Map[String, Column]): DataFrame = {
+  def tableFunction(
+      func: com.snowflake.snowpark.TableFunction,
+      args: Map[String, Column]): DataFrame = {
     // Use df.join to apply function result if args contains a DF column
     val sourceDFs = args.values.flatMap(_.expr.sourceDFs)
     if (sourceDFs.isEmpty) {
@@ -615,9 +620,9 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Log
  def tableFunction(func: Column): DataFrame = {
    func.expr match {
      case TFunction(funcName, args) =>
-        tableFunction(TableFunction(funcName), args.map(Column(_)))
+        tableFunction(com.snowflake.snowpark.TableFunction(funcName), args.map(Column(_)))
      case NamedArgumentsTableFunction(funcName, argMap) =>
-        tableFunction(TableFunction(funcName), argMap.map {
+        tableFunction(com.snowflake.snowpark.TableFunction(funcName), argMap.map {
          case (key, value) => key -> Column(value)
        })
      case _ => throw ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT()
diff --git a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
index 58bd69e7..d18d275c 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala
@@ -319,7 +319,7 @@ object JavaUtils {
  def stringArrayToStringSeq(arr: Array[String]): Seq[String] = arr
  def objectListToAnySeq(input: java.util.List[java.util.List[Object]]): Seq[Seq[Any]] =
-    input.asScala.map(list => list.asScala)
+    input.asScala.map(list => list.asScala.toSeq).toSeq
  def registerUDF(
      udfRegistration: UDFRegistration,
@@ -360,18 +360,26 @@
    map.asScala.toMap
  def javaMapToScalaWithVariantConversion(map: java.util.Map[_, _]): Map[Any, Any] =
-    map.asScala.map {
-      case (key, value: com.snowflake.snowpark_java.types.Variant) =>
-        key -> InternalUtils.toScalaVariant(value)
-      case (key, value) => key -> value
-    }.toMap
+    map.asScala
+      .map((e: (Any, Any)) => {
+        (e) match {
+          case (key, value: com.snowflake.snowpark_java.types.Variant) =>
+            key -> InternalUtils.toScalaVariant(value)
+          case (key, value) => key -> value
+        }
+      })
+      .toMap
  def scalaMapToJavaWithVariantConversion(map: Map[_, _]): java.util.Map[Object, Object] =
-    map.map {
-      case (key, value: com.snowflake.snowpark.types.Variant) =>
-        key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
-      case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
-    }.asJava
+    map
+      .map((e: (Any, Any)) => {
+        (e) match {
+          case (key, value: com.snowflake.snowpark.types.Variant) =>
+            key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
+          case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
+        }
+      })
+      .asJava
  def serialize(obj: Any): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
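The expanded lambdas in `JavaUtils` keep the wildcard-typed maps compiling on 2.13, where the old partial-function literal over `java.util.Map[_, _]` runs into stricter inference in the redesigned collections. A self-contained sketch of the same shape; the `JVariant`/`SVariant` types are hypothetical stand-ins for the snowpark_java and snowpark `Variant` classes:

```scala
import scala.collection.JavaConverters._

object VariantBridge {
  // Illustrative stand-ins for the real Variant types.
  final case class JVariant(json: String)
  final case class SVariant(json: String)
  def toScalaVariant(v: JVariant): SVariant = SVariant(v.json)

  // Explicitly typing the element as (Any, Any) sidesteps inference
  // problems with the existential element type of Map[_, _] on 2.13.
  def javaMapToScala(map: java.util.Map[_, _]): Map[Any, Any] =
    map.asScala
      .map((e: (Any, Any)) =>
        e match {
          case (key, value: JVariant) => key -> toScalaVariant(value)
          case (key, value) => key -> value
        })
      .toMap
}
```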
diff --git a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
index 92728eaf..e920ba31 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/ServerConnection.scala
@@ -52,6 +52,7 @@ import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.{
 import java.util
 import scala.collection.mutable
+import scala.reflect.classTag
 import scala.reflect.runtime.universe.TypeTag
 import scala.collection.JavaConverters._
@@ -317,7 +318,7 @@
  private[snowpark] def resultSetToRows(statement: Statement): Array[Row] = withValidConnection {
    val iterator = resultSetToIterator(statement)._1
-    val buff = mutable.ArrayBuilder.make[Row]()
+    val buff = mutable.ArrayBuilder.make[Row](classTag[Row])
    while (iterator.hasNext) {
      buff += iterator.next()
    }
diff --git a/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala b/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
index acf9b62e..ed8a9a36 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/UDXRegistrationHandler.scala
@@ -385,7 +385,7 @@ class UDXRegistrationHandler(session: Session) extends Logging {
      if (actionID <= session.getLastCanceledID) {
        throw ErrorMessage.MISC_QUERY_IS_CANCELLED()
      }
-      allUrls
+      allUrls.toSeq
    }
    (allImports, targetJarStageLocation)
  }
diff --git a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
index d464566c..4415afd0 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/Utils.scala
@@ -27,9 +27,9 @@ object Utils extends Logging {
  // Define the compat scala version instead of reading from property file
  // because it fails to read the property file in some environment such as
  // VSCode worksheet.
-  val ScalaCompatVersion: String = "2.12"
+  val ScalaCompatVersion: String = "2.13"
  // Minimum minor version. We require version to be greater than 2.12.9
-  val ScalaMinimumMinorVersion: String = "2.12.9"
+  val ScalaMinimumMinorVersion: String = "2.13.9"
  // Minimum GS version for us to identify as Snowpark client
  val MinimumGSVersionForSnowparkClientType: String = "5.20.0"
diff --git a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
index a3218758..381d079e 100644
--- a/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
+++ b/src/main/scala/com/snowflake/snowpark/internal/analyzer/SnowflakePlan.scala
@@ -74,9 +74,9 @@ class SnowflakePlan(
    }
    val supportAsyncMode = subqueryPlans.forall(_.supportAsyncMode)
    SnowflakePlan(
-      preQueries :+ queries.last,
+      (preQueries :+ queries.last).toSeq,
      newSchemaQuery,
-      newPostActions,
+      newPostActions.toSeq,
      session,
      sourcePlan,
      supportAsyncMode)
diff --git a/src/main/scala/com/snowflake/snowpark/types/Variant.scala b/src/main/scala/com/snowflake/snowpark/types/Variant.scala
index 28ce5b67..4f00f302 100644
--- a/src/main/scala/com/snowflake/snowpark/types/Variant.scala
+++ b/src/main/scala/com/snowflake/snowpark/types/Variant.scala
@@ -6,6 +6,7 @@
 import com.fasterxml.jackson.databind.node.{ArrayNode, JsonNodeFactory, ObjectNode}
 import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger}
 import java.sql.{Date, Time, Timestamp}
 import java.util.{List => JavaList, Map => JavaMap}
+import scala.jdk.FunctionConverters._
 import scala.collection.JavaConverters._
 import Variant._
 import org.apache.commons.codec.binary.{Base64, Hex}
@@ -225,7 +226,7 @@
    *
    * @since 0.2.0
    */
-  def this(list: JavaList[Object]) = this(list.asScala)
+  def this(list: JavaList[Object]) = this(list.asScala.toSeq)
  /**
   * Creates a Variant from array
@@ -244,7 +245,11 @@
  {
    def mapToNode(map: JavaMap[Object, Object]): ObjectNode = {
      val result = MAPPER.createObjectNode()
-      map.keySet().forEach(key => result.set(key.toString, objectToJsonNode(map.get(key))))
+      val consumer =
+        (key: Object) => result.set(key.toString, objectToJsonNode(map.get(key)))
+      map
+        .keySet()
+        .forEach(consumer.asJavaConsumer)
      result
    }
    obj match {
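`scala.jdk.FunctionConverters` (new in 2.13) converts Scala functions to Java functional interfaces explicitly; the `Variant` change above presumably reaches for it because the lambda's body (`ObjectNode.set`) returns a value, which gets in the way of the automatic SAM conversion to `java.util.function.Consumer`. A self-contained sketch of the same conversion, with an illustrative helper rather than the real `mapToNode`:

```scala
import scala.jdk.FunctionConverters._

object ForEachBridge {
  def keysAsStrings(map: java.util.Map[Object, Object]): java.util.List[String] = {
    val out = new java.util.ArrayList[String]()
    // Ending the body with () makes this an Object => Unit function,
    // which asJavaConsumer adapts to java.util.function.Consumer[Object].
    val consumer: Object => Unit = key => { out.add(key.toString); () }
    map.keySet().forEach(consumer.asJavaConsumer)
    out
  }
}
```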
diff --git a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
index c0f1ed3e..85bc310d 100644
--- a/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
+++ b/src/test/scala/com/snowflake/code_verification/JavaScalaAPISuite.scala
@@ -1,7 +1,7 @@
 package com.snowflake.code_verification
 import com.snowflake.snowpark.{CodeVerification, DataFrame}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 // verify API Java and Scala API contain same functions
 @CodeVerification
diff --git a/src/test/scala/com/snowflake/code_verification/PomSuite.scala b/src/test/scala/com/snowflake/code_verification/PomSuite.scala
index 6dca6810..74d6d331 100644
--- a/src/test/scala/com/snowflake/code_verification/PomSuite.scala
+++ b/src/test/scala/com/snowflake/code_verification/PomSuite.scala
@@ -1,7 +1,7 @@
 package com.snowflake.code_verification
 import com.snowflake.snowpark.CodeVerification
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 import scala.collection.mutable
diff --git a/src/test/scala/com/snowflake/perf/PerfBase.scala b/src/test/scala/com/snowflake/perf/PerfBase.scala
index 6969249a..508788d5 100644
--- a/src/test/scala/com/snowflake/perf/PerfBase.scala
+++ b/src/test/scala/com/snowflake/perf/PerfBase.scala
@@ -77,6 +77,7 @@ trait PerfBase extends SNTestBase {
    test(testName) {
      try {
        writeResult(testName, timer(func))
+        succeed
      } catch {
        case ex: Exception =>
          writeResult(testName, -1.0) // -1.0 if failed
@@ -84,7 +85,10 @@
      }
    }
  } else {
-    ignore(testName)(func)
+    ignore(testName) {
+      func
+      succeed
+    }
  }
 }
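Every `org.scalatest.FunSuite` import in the test tree is rewritten the same way: scalatest 3.2 (bumped in the pom above) moved the style traits into per-style packages, and a rename-on-import keeps each `extends FunSuite` declaration compiling unchanged. Sketch (the suite name is illustrative):

```scala
// scalatest 3.0.x:  import org.scalatest.FunSuite
// scalatest 3.2.x, renamed so the suite body needs no edits:
import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

class VersionSuite extends FunSuite {
  test("compat version is 2.13") {
    assert("2.13".startsWith("2."))
  }
}
```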
diff --git a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
index a28d167a..37bac0d4 100644
--- a/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/APIInternalSuite.scala
@@ -18,6 +18,7 @@ import com.snowflake.snowpark.internal.analyzer.{
 import com.snowflake.snowpark.types._
 import net.snowflake.client.core.SFSessionProperty
 import net.snowflake.client.jdbc.SnowflakeSQLException
+import org.scalatest.Assertion
 import java.nio.file.Files
 import java.sql.{Date, Timestamp}
@@ -26,6 +27,7 @@
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.Await
 import scala.concurrent.duration._
 import scala.util.Random
+import scala.language.postfixOps
 class APIInternalSuite extends TestData {
  private val userSchema: StructType = StructType(
@@ -578,10 +580,11 @@
    testWithAlteredSessionParameter(() => {
      import session.implicits._
      val schema = StructType(Seq(StructField("ID", LongType)))
-      val largeData = new ArrayBuffer[Row]()
+      val largeDataBuf = new ArrayBuffer[Row]()
      for (i <- 0 to 1024) {
-        largeData.append(Row(i.toLong))
+        largeDataBuf.append(Row(i.toLong))
      }
+      val largeData = largeDataBuf.toSeq
      // With specific schema
      var df = session.createDataFrame(largeData, schema)
      assert(df.snowflakePlan.queries.size == 3)
@@ -596,7 +599,7 @@
      for (i <- 0 to 1024) {
        inferData.append(i.toLong)
      }
-      df = inferData.toDF("id2")
+      df = inferData.toSeq.toDF("id2")
      assert(df.snowflakePlan.queries.size == 3)
      assert(df.snowflakePlan.queries(0).sql.trim().startsWith("CREATE SCOPED TEMPORARY TABLE"))
      assert(df.snowflakePlan.queries(1).sql.trim().startsWith("INSERT INTO"))
@@ -823,6 +826,7 @@
      val (rows, meta) = session.conn.getResultAndMetadata(session.sql(query).snowflakePlan)
      assert(rows.length == 0 || rows(0).length == meta.size)
    }
+    succeed
  }
  // reader
@@ -895,7 +899,7 @@
    assert(ex2.errorCode.equals("0321"))
  }
-  def checkExecuteAndGetQueryId(df: DataFrame): Unit = {
+  def checkExecuteAndGetQueryId(df: DataFrame): Assertion = {
    val query = Query.resultScanQuery(df.executeAndGetQueryId())
    val res = query.runQueryGetResult(session.conn, mutable.HashMap.empty[String, String], false)
    res.attributes
@@ -907,7 +911,7 @@
    checkExecuteAndGetQueryIdWithStatementParameter(df)
  }
-  private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Unit = {
+  private def checkExecuteAndGetQueryIdWithStatementParameter(df: DataFrame): Assertion = {
    val testQueryTagValue = s"test_query_tag_${Random.nextLong().abs}"
    val queryId = df.executeAndGetQueryId(Map("QUERY_TAG" -> testQueryTagValue))
    val rows = session
@@ -1007,7 +1011,7 @@
    largeData.append(
      Row(1025, null, null, null, null, null, null, null, null, null, null, null, null))
-    val df = session.createDataFrame(largeData, schema)
+    val df = session.createDataFrame(largeData.toSeq, schema)
    checkExecuteAndGetQueryId(df)
    // Statement parameters are applied for all queries.
@@ -1039,6 +1043,7 @@
    // case 2: test int/boolean parameter
    multipleQueriesDF1.executeAndGetQueryId(
      Map("STATEMENT_TIMEOUT_IN_SECONDS" -> 100, "USE_CACHED_RESULT" -> false))
+    succeed
  }
  test("VariantTypes.getType") {
@@ -1052,7 +1057,7 @@
    assert(Variant.VariantTypes.getType("Timestamp") == Variant.VariantTypes.Timestamp)
    assert(Variant.VariantTypes.getType("Array") == Variant.VariantTypes.Array)
    assert(Variant.VariantTypes.getType("Object") == Variant.VariantTypes.Object)
-    intercept[Exception] { Variant.VariantTypes.getType("not_exist_type") }
+    assertThrows[Exception] { Variant.VariantTypes.getType("not_exist_type") }
  }
  test("HasCachedResult doesn't cache again") {
diff --git a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
index 2509176a..b8363191 100644
--- a/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/DropTempObjectsSuite.scala
@@ -109,16 +109,16 @@ class DropTempObjectsSuite extends SNTestBase {
      TempObjectType.Table,
      "db.schema.tempName1",
      TempType.Temporary)
-    assertTrue(session.getTempObjectMap.contains("db.schema.tempName1"))
+    assert(session.getTempObjectMap.contains("db.schema.tempName1"))
    session.recordTempObjectIfNecessary(
      TempObjectType.Table,
      "db.schema.tempName2",
      TempType.ScopedTemporary)
-    assertFalse(session.getTempObjectMap.contains("db.schema.tempName2"))
+    assert(!session.getTempObjectMap.contains("db.schema.tempName2"))
    session.recordTempObjectIfNecessary(
      TempObjectType.Table,
      "db.schema.tempName3",
      TempType.Permanent)
-    assertFalse(session.getTempObjectMap.contains("db.schema.tempName3"))
+    assert(!session.getTempObjectMap.contains("db.schema.tempName3"))
  }
 }
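The recurring `succeed` and `assertThrows` edits in these suites follow from `SNTestBase` moving to `AsyncFunSuite` (later in this diff): each test body must now end in an `Assertion`, so `Unit`-returning loops and `verify(...)` calls get a trailing `succeed`, and `intercept` (which returns the caught exception) becomes `assertThrows` (which returns an `Assertion`). For example:

```scala
import org.scalatest.funsuite.AsyncFunSuite

class AssertionStyleSuite extends AsyncFunSuite {
  test("side-effecting checks end with succeed") {
    (1 to 3).foreach(i => assert(i > 0)) // foreach returns Unit
    succeed // explicit final Assertion required by the async style
  }

  test("assertThrows returns an Assertion") {
    assertThrows[ArithmeticException](1 / 0)
  }
}
```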
diff --git a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
index 937b93e6..714fceb1 100644
--- a/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ErrorMessageSuite.scala
@@ -6,7 +6,7 @@ import com.snowflake.snowpark.internal.ParameterUtils.{
 MIN_REQUEST_TIMEOUT_IN_SECONDS,
 SnowparkRequestTimeoutInSeconds
 }
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 class ErrorMessageSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
index 1ada9af9..a78fab76 100644
--- a/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ExpressionAndPlanNodeSuite.scala
@@ -250,6 +250,7 @@ class ExpressionAndPlanNodeSuite extends SNTestBase {
    emptyChecker(CurrentRow)
    emptyChecker(UnspecifiedFrame)
    binaryChecker(SpecifiedWindowFrame(RowFrame, _, _))
+    succeed
  }
  test("star children and dependent columns") {
@@ -475,6 +476,7 @@
    leafAnalyzerChecker(CurrentRow)
    leafAnalyzerChecker(UnspecifiedFrame)
    binaryAnalyzerChecker(SpecifiedWindowFrame(RowFrame, _, _))
+    succeed
  }
  test("star - analyze") {
@@ -489,6 +491,7 @@
    assert(exp.analyze(x => x) == exp)
    assert(exp.analyze(_ => att2) == att2)
    leafAnalyzerChecker(Star(Seq.empty))
+    succeed
  }
  test("WindowSpecDefinition - analyze") {
@@ -988,6 +991,7 @@
      assert(key.name == "\"COL3\"")
      assert(value.name == "\"COL3\"")
    }
+    succeed
  }
  test("TableDelete - Analyzer") {
@@ -1125,5 +1129,6 @@
    leafSimplifierChecker(
      SnowflakePlan(Seq.empty, "222", session, None, supportAsyncMode = false))
+    succeed
  }
 }
diff --git a/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala b/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
index c3007066..eb3df273 100644
--- a/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/FatJarBuilderSuite.scala
@@ -5,7 +5,7 @@ import java.util.jar.{JarFile, JarOutputStream}
 import java.util.zip.ZipException
 import com.snowflake.snowpark.internal.{FatJarBuilder, JavaCodeCompiler}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
diff --git a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
index fbd6101d..04d1cefd 100644
--- a/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/JavaAPISuite.scala
@@ -1,6 +1,6 @@
 package com.snowflake.snowpark
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 import com.snowflake.snowpark_test._
 import java.io.ByteArrayOutputStream
diff --git a/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala b/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
index 4e3faa54..a347ecd8 100644
--- a/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/JavaCodeCompilerSuite.scala
@@ -1,7 +1,7 @@
 package com.snowflake.snowpark
 import com.snowflake.snowpark.internal.{InMemoryClassObject, JavaCodeCompiler, UDFClassPath}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 class JavaCodeCompilerSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala b/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
index d3a31d9e..6ca97dd3 100644
--- a/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/LoggingSuite.scala
@@ -1,7 +1,7 @@
 package com.snowflake.snowpark
 import com.snowflake.snowpark.internal.Logging
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}
 class LoggingSuite extends FunSuite {
diff --git a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
index 86d43968..92f1acc4 100644
--- a/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/NewColumnReferenceSuite.scala
@@ -3,6 +3,7 @@ package com.snowflake.snowpark
 import com.snowflake.snowpark.internal.ParameterUtils
 import com.snowflake.snowpark.internal.analyzer._
 import com.snowflake.snowpark.types.IntegerType
+import org.scalatest.Assertion
 import scala.language.implicitConversions
@@ -219,7 +220,9 @@
  case class TestInternalAlias(name: String) extends TestColumnName
  implicit def stringToOriginalName(name: String): TestOriginalName = TestOriginalName(name)
-  private def verifyOutputName(output: Seq[Attribute], columnNames: Seq[TestColumnName]): Unit = {
+  private def verifyOutputName(
+      output: Seq[Attribute],
+      columnNames: Seq[TestColumnName]): Assertion = {
    assert(output.size == columnNames.size)
    assert(output.map(_.name).zip(columnNames).forall {
      case (name, TestOriginalName(n)) => name == quoteName(n)
@@ -280,6 +283,7 @@
    verifyUnaryNode(child => TableUpdate("a", Map.empty, None, Some(child)))
    verifyUnaryNode(child => SnowflakeCreateTable("a", SaveMode.Append, Some(child)))
    verifyBinaryNode((plan1, plan2) => SimplifiedUnion(Seq(plan1, plan2)))
+    succeed
  }
  private val project1 = Project(Seq(Attribute("a", IntegerType)), Range(1, 1, 1))
diff --git a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
index 75365e5b..d461ad5f 100644
--- a/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ParameterSuite.scala
@@ -77,6 +77,7 @@ class ParameterSuite extends SNTestBase {
    // no need to verify PKCS#1 format key additionally,
    // since all Github Action tests use PKCS#1 key to authenticate with Snowflake server.
    ParameterUtils.parsePrivateKey(generatePKCS8Key())
+    succeed
  }
  private def generatePKCS8Key(): String = {
diff --git a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
index cba06aa4..eeb859e9 100644
--- a/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ReplSuite.scala
@@ -1,15 +1,14 @@
 package com.snowflake.snowpark
-import java.io.{BufferedReader, OutputStreamWriter, StringReader}
+import java.io.{BufferedReader, OutputStreamWriter, StringReader, PrintWriter => JPrintWriter}
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths, StandardCopyOption}
-
 import com.snowflake.snowpark.internal.Utils
 import scala.tools.nsc.Settings
-import scala.tools.nsc.interpreter._
 import scala.tools.nsc.util.stringFromStream
 import scala.sys.process._
+import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}
 @UDFTest
 class ReplSuite extends TestData {
@@ -48,7 +47,6 @@
    Console.withOut(outputStream) {
      val input = new BufferedReader(new StringReader(preLoad + code))
      val output = new JPrintWriter(new OutputStreamWriter(outputStream))
-      val repl = new ILoop(input, output)
      val settings = new Settings()
      if (inMemory) {
        settings.processArgumentString("-Yrepl-class-based")
@@ -56,7 +54,8 @@
        settings.processArgumentString("-Yrepl-class-based -Yrepl-outdir repl_classes")
      }
      settings.classpath.value = sys.props("java.class.path")
-      repl.process(settings)
+      val repl = new ILoop(ShellConfig(settings), input, output)
+      repl.run(settings)
    }
  }.replaceAll("scala> ", "")
 }
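The REPL embedding API moved in 2.13: `ILoop` now lives in `scala.tools.nsc.interpreter.shell`, takes a `ShellConfig` ahead of the reader/writer, and is driven with `run(settings)` rather than `process(settings)`. A minimal sketch of the new wiring, assuming scala-compiler on the classpath (the `EmbeddedRepl` wrapper is illustrative):

```scala
import java.io.{BufferedReader, OutputStreamWriter, PrintWriter, StringReader}
import scala.tools.nsc.Settings
import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig}

object EmbeddedRepl {
  // Feeds `code` to an embedded 2.13 REPL and returns when it exits.
  def run(code: String): Unit = {
    val settings = new Settings()
    settings.usejavacp.value = true // reuse the JVM classpath
    val in = new BufferedReader(new StringReader(code))
    val out = new PrintWriter(new OutputStreamWriter(System.out))
    val repl = new ILoop(ShellConfig(settings), in, out)
    repl.run(settings) // 2.12 equivalent: new ILoop(in, out).process(settings)
  }
}
```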
diff --git a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
index ebb5f9c4..e13261a2 100644
--- a/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/ResultAttributesSuite.scala
@@ -40,6 +40,7 @@ class ResultAttributesSuite extends SNTestBase {
    val attribute = getAttributesWithTypes(tableName, integers)
    assert(attribute.length == integers.length)
    integers.indices.foreach(index => assert(attribute(index).dataType == LongType))
+    succeed
  }
  test("float data type") {
@@ -47,6 +48,7 @@
    val attribute = getAttributesWithTypes(tableName, floats)
    assert(attribute.length == floats.length)
    floats.indices.foreach(index => assert(attribute(index).dataType == DoubleType))
+    succeed
  }
  test("string data types") {
@@ -54,6 +56,7 @@
    val attribute = getAttributesWithTypes(tableName, strings)
    assert(attribute.length == strings.length)
    strings.indices.foreach(index => assert(attribute(index).dataType == StringType))
+    succeed
  }
  test("binary data types") {
@@ -61,6 +64,7 @@
    val attribute = getAttributesWithTypes(tableName, binaries)
    assert(attribute.length == binaries.length)
    binaries.indices.foreach(index => assert(attribute(index).dataType == BinaryType))
+    succeed
  }
  test("logical data type") {
@@ -69,6 +73,7 @@
    assert(attributes.length == 1)
    assert(attributes.head.dataType == BooleanType)
    dropTable(tableName)
+    succeed
  }
  test("date & time data type") {
@@ -83,6 +88,7 @@
    val attribute = getAttributesWithTypes(tableName, dates.map(_._1))
    assert(attribute.length == dates.length)
    dates.indices.foreach(index => assert(attribute(index).dataType == dates(index)._2))
+    succeed
  }
  test("semi-structured data types") {
@@ -106,5 +112,6 @@
      index => assert(attribute(index).dataType == ArrayType(StringType)))
+    succeed
  }
 }
session.runQuery(s"alter session unset $parameter") } + succeed } def testWithTimezone[T](thunk: => T, timezone: String): T = { diff --git a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala index 5269f4b2..6b299064 100644 --- a/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/ServerConnectionSuite.scala @@ -68,7 +68,8 @@ class ServerConnectionSuite extends SNTestBase { for (i <- 0 to 1024) { largeData.append(Row(i)) } - val df2 = session.createDataFrame(largeData, StructType(Seq(StructField("ID", LongType)))) + val df2 = + session.createDataFrame(largeData.toSeq, StructType(Seq(StructField("ID", LongType)))) val ex2 = intercept[SnowparkClientException] { session.conn.executeAsync(df2.snowflakePlan) } diff --git a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala index e7c857a3..c4a40285 100644 --- a/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/SnowflakePlanSuite.scala @@ -126,7 +126,8 @@ class SnowflakePlanSuite extends SNTestBase { for (i <- 0 to 1024) { largeData.append(Row(i)) } - val df2 = session.createDataFrame(largeData, StructType(Seq(StructField("ID", LongType)))) + val df2 = + session.createDataFrame(largeData.toSeq, StructType(Seq(StructField("ID", LongType)))) assert(!df2.snowflakePlan.supportAsyncMode && !df2.clone.snowflakePlan.supportAsyncMode) assert(!session.sql(" put file").snowflakePlan.supportAsyncMode) assert(!session.sql("get file ").snowflakePlan.supportAsyncMode) diff --git a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala index 4864c38f..1bc2b98f 100644 --- a/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala +++ b/src/test/scala/com/snowflake/snowpark/SnowparkSFConnectionHandlerSuite.scala @@ -1,6 +1,6 @@ package com.snowflake.snowpark -import org.scalatest.FunSuite +import org.scalatest.funsuite.{AnyFunSuite => FunSuite} import com.snowflake.snowpark.internal.SnowparkSFConnectionHandler class SnowparkSFConnectionHandlerSuite extends FunSuite { diff --git a/src/test/scala/com/snowflake/snowpark/TestUtils.scala b/src/test/scala/com/snowflake/snowpark/TestUtils.scala index e1abedff..7fa3ea36 100644 --- a/src/test/scala/com/snowflake/snowpark/TestUtils.scala +++ b/src/test/scala/com/snowflake/snowpark/TestUtils.scala @@ -18,7 +18,7 @@ import com.snowflake.snowpark.internal.UDFClassPath.getPathForClass import com.snowflake.snowpark.internal.analyzer.{quoteName, quoteNameWithoutUpperCasing} import com.snowflake.snowpark.types._ import com.snowflake.snowpark_java.types.{InternalUtils, StructType => JavaStructType} -import org.scalatest.{BeforeAndAfterAll, Tag} +import org.scalatest.{Assertion, Assertions, BeforeAndAfterAll, Tag} import java.util.{Locale, Properties} import com.snowflake.snowpark.Session.loadConfFromFile @@ -101,7 +101,7 @@ object TestUtils extends Logging { s"insert into $name values ${data.map("(" + _.toString + ")").mkString(",")}") def insertIntoTable(name: String, data: java.util.List[Object], session: Session): Unit = - insertIntoTable(name, data.asScala.map(_.toString), session) + insertIntoTable(name, data.asScala.map(_.toString).toSeq, session) def uploadFileToStage( stageName: String, @@ -275,17 +275,17 @@ object TestUtils extends 
diff --git a/src/test/scala/com/snowflake/snowpark/TestUtils.scala b/src/test/scala/com/snowflake/snowpark/TestUtils.scala
index e1abedff..7fa3ea36 100644
--- a/src/test/scala/com/snowflake/snowpark/TestUtils.scala
+++ b/src/test/scala/com/snowflake/snowpark/TestUtils.scala
@@ -18,7 +18,7 @@ import com.snowflake.snowpark.internal.UDFClassPath.getPathForClass
 import com.snowflake.snowpark.internal.analyzer.{quoteName, quoteNameWithoutUpperCasing}
 import com.snowflake.snowpark.types._
 import com.snowflake.snowpark_java.types.{InternalUtils, StructType => JavaStructType}
-import org.scalatest.{BeforeAndAfterAll, Tag}
+import org.scalatest.{Assertion, Assertions, BeforeAndAfterAll, Tag}
 import java.util.{Locale, Properties}
 import com.snowflake.snowpark.Session.loadConfFromFile
@@ -101,7 +101,7 @@
    s"insert into $name values ${data.map("(" + _.toString + ")").mkString(",")}")
  def insertIntoTable(name: String, data: java.util.List[Object], session: Session): Unit =
-    insertIntoTable(name, data.asScala.map(_.toString), session)
+    insertIntoTable(name, data.asScala.map(_.toString).toSeq, session)
  def uploadFileToStage(
      stageName: String,
@@ -275,17 +275,17 @@ object TestUtils extends Logging {
    res
  }
-  def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Unit = {
+  def checkResult(result: Array[Row], expected: Seq[Row], sort: Boolean = true): Assertion = {
    val sorted = if (sort) result.sortBy(_.toString) else result
    val sortedExpected = if (sort) expected.sortBy(_.toString) else expected
-    assert(
+    Assertions.assert(
      compare(sorted, sortedExpected.toArray[Row]),
      s"${sorted.map(_.toString).mkString("[", ", ", "]")} != " +
        s"${sortedExpected.map(_.toString).mkString("[", ", ", "]")}")
  }
-  def checkResult(result: Array[Row], expected: java.util.List[Row], sort: Boolean): Unit =
-    checkResult(result, expected.asScala, sort)
+  def checkResult(result: Array[Row], expected: java.util.List[Row], sort: Boolean): Assertion =
+    checkResult(result, expected.asScala.toSeq, sort)
  def runQueryInSession(session: Session, sql: String): Unit =
    session.runQuery(sql)
diff --git a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
index 3c3f388d..d4218287 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFClasspathSuite.scala
@@ -48,6 +48,7 @@ class UDFClasspathSuite extends SNTestBase {
    assert(mockSession.listFilesInStage(stageName1).size == jarInClassPath.size)
    // Assert that no dependencies are uploaded in second time
    verify(mockSession, never()).doUpload(any(), any())
+    succeed
  }
  test("Test that udf function's class path is automatically added") {
@@ -58,6 +59,7 @@
    udfR.registerUDF(Some(func), _toUdf((a: Int) => a + a), None)
    verify(mockSession, atLeastOnce())
      .addDependency(path)
+    succeed
  }
  test("Test that snowpark jar is NOT uploaded if stage path is available") {
@@ -85,6 +87,7 @@
      .addDependency(expectedPath)
    // createJavaUDF should be only invoked once
    verify(udfR, times(1)).createJavaUDF(any(), any(), any(), any(), any(), any(), any())
+    succeed
  }
  test(
@@ -106,6 +109,7 @@
    }
    // createJavaUDF should be invoked twice as it is retried after fixing classpath
    verify(udfR, times(2)).createJavaUDF(any(), any(), any(), any(), any(), any(), any())
+    succeed
  }
  test("Test for getPathForClass") {
diff --git a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
index 9a1eae3a..9a3b401c 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFInternalSuite.scala
@@ -48,6 +48,7 @@ class UDFInternalSuite extends TestData {
    }
    verify(mockSession, times(1)).removeDependency(path)
    verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
+    succeed
  }
  test("Test permanent udf not failing back to upload jar", JavaStoredProcExclude) {
@@ -82,6 +83,7 @@
    }
    verify(mockSession, times(1)).removeDependency(path)
    verify(mockSession, times(1)).addPackage("com.snowflake:snowpark:latest")
+    succeed
  }
  test("Test add version logic", JavaStoredProcExclude) {
@@ -103,6 +105,7 @@
    val path = UDFClassPath.getPathForClass(classOf[com.snowflake.snowpark.Session]).get
    verify(mockSession, never()).addDependency(path)
    verify(mockSession, times(1)).addPackage(Utils.clientPackageName)
+    succeed
  }
  test("Confirm jar files to be uploaded to expected location", JavaStoredProcExclude) {
@@ -170,6 +173,7 @@ class UDFInternalSuite extends TestData {
      mockSession1.udf.registerPermanent(funcName1, udf, stageName1)
      mockSession2.udf.registerPermanent(funcName2, udf, stageName1)
      verify(mockSession2, never()).doUpload(any(), any())
+      succeed
    } finally {
      session.runQuery(s"drop function if exists $funcName1(INT)")
diff --git a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
index 1c0984b6..9a5a30ee 100644
--- a/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark/UDFRegistrationSuite.scala
@@ -10,6 +10,7 @@ import scala.reflect.internal.util.BatchSourceFile
 import scala.reflect.io.{AbstractFile, VirtualDirectory}
 import scala.tools.nsc.GenericRunnerSettings
 import scala.tools.nsc.interpreter.IMain
+import scala.tools.nsc.interpreter.shell.ReplReporterImpl
 import scala.util.Random
 @UDFTest
@@ -107,7 +108,7 @@ class UDFRegistrationSuite extends SNTestBase with FileUtils {
      val targetDir = Files.createTempDirectory(s"snowpark_test_target_")
      settings.Yreploutdir.value = targetDir.toFile.getAbsolutePath
    }
-    val interpreter: IMain = new IMain(settings)
+    val interpreter: IMain = new IMain(settings, new ReplReporterImpl(settings))
    interpreter.compileSources(new BatchSourceFile(AbstractFile.getFile(new File(fileName))))
    interpreter.classLoader.loadClass(s"$packageName.$className")
@@ -123,6 +124,7 @@
    val onDiskName = s"DynamicCompile${Random.nextInt().abs}"
    val onDiskClass = generateDynamicClass(packageName, onDiskName, false)
    session.udf.handler.addClassToDependencies(onDiskClass)
+    succeed
  }
  test("ls file") {
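Like `ILoop`, the 2.13 `IMain` lost its single-argument constructor; one of the remaining overloads takes the `Settings` plus a `ReplReporter`, satisfied here with `scala.tools.nsc.interpreter.shell.ReplReporterImpl`. A minimal sketch of the new construction (the `EmbeddedCompiler` wrapper is illustrative):

```scala
import scala.tools.nsc.GenericRunnerSettings
import scala.tools.nsc.interpreter.IMain
import scala.tools.nsc.interpreter.shell.ReplReporterImpl

object EmbeddedCompiler {
  // Builds an in-process interpreter the 2.13 way;
  // on 2.12 this was simply `new IMain(settings)`.
  def make(): IMain = {
    val settings = new GenericRunnerSettings(msg => sys.error(msg))
    settings.usejavacp.value = true
    new IMain(settings, new ReplReporterImpl(settings))
  }
}
```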
Utils.getClass.getMethod("Version") + succeed } test("test mask secrets") { @@ -191,10 +194,11 @@ class UtilsSuite extends SNTestBase { double: java.lang.Double) test("Non-nullable types") { - TypeToSchemaConverter - .inferSchema[(Boolean, Byte, Short, Int, Long, Float, Double)]() - .treeString(0) == - """root + assert( + TypeToSchemaConverter + .inferSchema[(Boolean, Byte, Short, Int, Long, Float, Double)]() + .treeString(0) == + """root | |--_1: Boolean (nullable = false) | |--_2: Byte (nullable = false) | |--_3: Short (nullable = false) @@ -202,32 +206,33 @@ class UtilsSuite extends SNTestBase { | |--_5: Long (nullable = false) | |--_6: Float (nullable = false) | |--_7: Double (nullable = false) - |""".stripMargin + |""".stripMargin) } test("Nullable types") { - TypeToSchemaConverter - .inferSchema[( - Option[Int], - JavaBoolean, - JavaByte, - JavaShort, - JavaInteger, - JavaLong, - JavaFloat, - JavaDouble, - Array[Boolean], - Map[String, Double], - JavaBigDecimal, - BigDecimal, - Variant, - Geography, - Date, - Time, - Timestamp, - Geometry)]() - .treeString(0) == - """root + assert( + TypeToSchemaConverter + .inferSchema[( + Option[Int], + JavaBoolean, + JavaByte, + JavaShort, + JavaInteger, + JavaLong, + JavaFloat, + JavaDouble, + Array[Boolean], + Map[String, Double], + JavaBigDecimal, + BigDecimal, + Variant, + Geography, + Date, + Time, + Timestamp, + Geometry)]() + .treeString(0) == + """root | |--_1: Integer (nullable = true) | |--_2: Boolean (nullable = true) | |--_3: Byte (nullable = true) @@ -246,7 +251,7 @@ class UtilsSuite extends SNTestBase { | |--_16: Time (nullable = true) | |--_17: Timestamp (nullable = true) | |--_18: Geometry (nullable = true) - |""".stripMargin + |""".stripMargin) } test("normalizeStageLocation") { @@ -414,6 +419,7 @@ class UtilsSuite extends SNTestBase { // println(s"test: $name") Utils.validateObjectName(name) } + succeed } test("test Utils.getUDFUploadPrefix()") { @@ -432,6 +438,7 @@ class UtilsSuite extends SNTestBase { // println(s"test: $name") assert(Utils.getUDFUploadPrefix(name).matches("[\\w]+")) } + succeed } test("negative test Utils.validateObjectName()") { @@ -484,6 +491,7 @@ class UtilsSuite extends SNTestBase { val ex = intercept[SnowparkClientException] { Utils.validateObjectName(name) } assert(ex.getMessage.replaceAll("\n", "").matches(".*The object name .* is invalid.")) } + succeed } test("os name") { @@ -506,6 +514,7 @@ class UtilsSuite extends SNTestBase { testItems.foreach { item => assert(Utils.convertWindowsPathToLinux(item._1).equals(item._2)) } + succeed } test("Utils.version matches pom version") { @@ -533,6 +542,7 @@ class UtilsSuite extends SNTestBase { val sleep7 = Utils.retrySleepTimeInMS(7) assert(sleep7 >= 30000 && sleep7 <= 60000) } + succeed } test("Utils.isRetryable") { @@ -675,7 +685,7 @@ class UtilsSuite extends SNTestBase { } object LoggingTester extends Logging { - def test(): Unit = { + def test(): Assertion = { // no error report val password = "failed_error_log_password" logInfo(s"info PASSWORD=$password") @@ -692,5 +702,6 @@ object LoggingTester extends Logging { logWarning(s"warning PASSWORD=$password", exception) logError(s"error PASSWORD=$password", exception) + succeed } } diff --git a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala index 84ad53bf..dd7cd02c 100644 --- a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala @@ -4,10 +4,9 @@ import 
diff --git a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
index 84ad53bf..dd7cd02c 100644
--- a/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/AsyncJobSuite.scala
@@ -4,10 +4,9 @@ import com.snowflake.snowpark.functions._
 import com.snowflake.snowpark.types._
 import com.snowflake.snowpark._
 import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.{BeforeAndAfterEach, Tag}
+import org.scalatest.{Assertion, BeforeAndAfterEach, Tag}
 import scala.concurrent.Future
-import scala.concurrent.ExecutionContext.Implicits.global
 import scala.util.{Failure, Random, Success}
 class AsyncJobSuite extends TestData with BeforeAndAfterEach {
@@ -352,7 +351,7 @@
  // This function is copied from DataFrameReader.testReadFile
  def testReadFile(testName: String, testTags: Tag*)(
-      thunk: (() => DataFrameReader) => Unit): Unit = {
+      thunk: (() => DataFrameReader) => Assertion): Unit = {
    // test select
    test(testName + " - SELECT", testTags: _*) {
      thunk(() => session.read)
@@ -501,6 +500,7 @@
      df.write.mode(SaveMode.Overwrite).async.saveAsTable(list).getResult()
      checkAnswer(session.table(tableName), Seq(Row(1), Row(2), Row(3)))
      dropTable(tableName)
+      succeed
    } finally {
      dropTable(tableName)
    }
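Note that both this suite and `SNTestBase` drop `import scala.concurrent.ExecutionContext.Implicits.global`: the `AsyncFunSuite` base already provides an implicit `executionContext` for the `Future` combinators used in tests, and importing the global one alongside it would risk an ambiguous implicit. A sketch under that assumption:

```scala
import scala.concurrent.Future
import org.scalatest.funsuite.AsyncFunSuite

class FutureStyleSuite extends AsyncFunSuite {
  // No ExecutionContext import: AsyncFunSuite supplies `executionContext`,
  // a serial context that keeps async assertions deterministic.
  test("map a Future without importing a global EC") {
    Future(21).map(n => assert(n * 2 == 42))
  }
}
```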
df.select(col("\"Col\"")).collect() } - intercept[Exception] { + assertThrows[Exception] { df.select(col("COL .")).collect() } - intercept[Exception] { + assertThrows[Exception] { df.select(col("\"CoL\"")).collect() } } @@ -386,13 +387,13 @@ class ColumnSuite extends TestData { checkAnswer(df.select($"COL"), Row(1) :: Nil) checkAnswer(df.select($"CoL"), Row(1) :: Nil) checkAnswer(df.select($""""COL""""), Row(1) :: Nil) - intercept[Exception] { + assertThrows[Exception] { df.select($""""Col"""").collect() } - intercept[Exception] { + assertThrows[Exception] { df.select($"COL .").collect() } - intercept[Exception] { + assertThrows[Exception] { df.select($""""CoL"""").collect() } } @@ -406,10 +407,10 @@ class ColumnSuite extends TestData { checkAnswer(df.select("COL"), Row(1) :: Nil) checkAnswer(df.select("CoL"), Row(1) :: Nil) checkAnswer(df.select("\"COL\""), Row(1) :: Nil) - intercept[Exception] { + assertThrows[Exception] { df.select("\"Col\"").collect() } - intercept[Exception] { + assertThrows[Exception] { df.select("COL .").collect() } } @@ -433,16 +434,16 @@ class ColumnSuite extends TestData { checkAnswer(df.select(sqlExpr("\"col\" + 10")), Row(12) :: Nil) checkAnswer(df.filter(sqlExpr("col < 1")), Nil) checkAnswer(df.filter(sqlExpr("\"col\" = 2")).select(col("col")), Row(1) :: Nil) - intercept[Exception] { + assertThrows[Exception] { df.select(sqlExpr("\"Col\"")).collect() } - intercept[Exception] { + assertThrows[Exception] { df.select(sqlExpr("COL .")).collect() } - intercept[Exception] { + assertThrows[Exception] { df.select(sqlExpr("\"CoL\"")).collect() } - intercept[Exception] { + assertThrows[Exception] { df.select(sqlExpr("col .")).collect() } } @@ -709,7 +710,7 @@ class ColumnSuite extends TestData { new Date(timestamp - 100))) } - val df = session.createDataFrame(largeData, schema) + val df = session.createDataFrame(largeData.toSeq, schema) // scala style checks dosn't support to put all of these expression in one filter() // So split it as 2 steps. 
@@ -709,7 +710,7 @@
        new Date(timestamp - 100)))
    }
-    val df = session.createDataFrame(largeData, schema)
+    val df = session.createDataFrame(largeData.toSeq, schema)
    // scala style checks dosn't support to put all of these expression in one filter()
    // So split it as 2 steps.
    val df2 = df.filter(
diff --git a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
index 485bf4c2..e59398a3 100644
--- a/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/CopyableDataFrameSuite.scala
@@ -944,6 +944,7 @@ class CopyableDataFrameSuite extends SNTestBase {
      val cloned = df.clone
      assert(cloned.isInstanceOf[CopyableDataFrame])
      cloned.copyInto(testTableName, Seq(col("$1").as("B")))
+      succeed
    } finally {
      dropTable(testTableName)
    }
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
index 3753a480..445fb392 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameAggregateSuite.scala
@@ -3,7 +3,6 @@ package com.snowflake.snowpark_test
 import com.snowflake.snowpark.functions._
 import com.snowflake.snowpark._
 import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.Matchers.the
 import java.sql.ResultSet
@@ -195,6 +194,7 @@
        assert(values_1_B.contains(row.getInt(1)) && values_2_B.contains(row.getInt(2)))
      }
    }
+    succeed
  }
  test("RelationalGroupedDataFrame.avg()/mean()") {
@@ -308,7 +308,7 @@
  // Used temporary VIEW which is not supported by owner's mode stored proc yet
  test("Window functions inside aggregate functions", JavaStoredProcExcludeOwner) {
    def checkWindowError(df: => DataFrame): Unit = {
-      the[SnowflakeSQLException] thrownBy {
+      assertThrows[SnowflakeSQLException] {
        df.collect()
      }
    }
@@ -477,13 +477,13 @@
    /*
     TODO: Add another test with eager analysis
     */
-    intercept[SnowflakeSQLException] {
+    assertThrows[SnowflakeSQLException] {
      courseSales.groupBy().agg(grouping($"course")).collect()
    }
    /*
     * TODO: Add another test with eager analysis
     */
-    intercept[SnowflakeSQLException] {
+    assertThrows[SnowflakeSQLException] {
      courseSales.groupBy().agg(grouping_id($"course")).collect()
    }
  }
@@ -546,6 +546,7 @@
      checkAnswer(kurtosisVal, Seq(Row(aggKurtosisResult.getDouble(1))))
    }
    statement.close()
+    succeed
  }
  test("SN - zero moments") {
@@ -699,7 +700,7 @@
    checkAnswer(session.sql("SELECT x FROM tempView GROUP BY x HAVING COUNT_IF(NULL) > 0"), Nil)
-    val error = intercept[SnowflakeSQLException] {
+    assertThrows[SnowflakeSQLException] {
      session.sql("SELECT COUNT_IF(x) FROM tempView").collect()
    }
  }
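The `checkWindowError` change above also retires `org.scalatest.Matchers.the`: in scalatest 3.2 the `Matchers` trait moved to `org.scalatest.matchers.should.Matchers`, and rather than chasing the new import, the suite switches to the style-independent `assertThrows`. For example:

```scala
import org.scalatest.funsuite.AnyFunSuite

class ThrownBySuite extends AnyFunSuite {
  test("assertThrows replaces `the [X] thrownBy`") {
    // scalatest 3.0 Matchers style:  the[ArithmeticException] thrownBy { 1 / 0 }
    assertThrows[ArithmeticException] { 1 / 0 }
  }
}
```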
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
index ce04e229..82962997 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameJoinSuite.scala
@@ -222,6 +222,7 @@ trait DataFrameJoinSuite extends SNTestBase {
      assert(ex2.getMessage.contains("INTCOL"))
      assert(ex2.getMessage().contains("ambiguous"))
    }
+    succeed
  }
  test("join -- expressions on ambiguous columns") {
@@ -304,7 +305,7 @@
        lhs.join(rhs, Seq("intcol"), joinType).select(lhs("negcol"), rhs("negcol")),
        Row(-1, -10) :: Row(-2, -20) :: Nil)
    }
-
+    succeed
  }
  test("Columns with and without quotes") {
@@ -400,6 +401,7 @@ trait DataFrameJoinSuite extends SNTestBase {
        assert(ex.message.contains(msg))
      }
    }
+    succeed
  }
  test("clone can help these self joins") {
@@ -616,6 +618,7 @@
    val dfOne = df.select(lit(1).as("a"))
    val dfTwo = session.range(10).select(lit(2).as("b"))
    dfOne.join(dfTwo, $"a" === $"b", "left").collect()
+    succeed
  }
  test("name alias in multiple join") {
@@ -645,6 +648,7 @@
          df_end_stations("station_name"),
          df_trips("starttime"))
        .collect()
+      succeed
    } finally {
      dropTable(tableTrips)
@@ -679,6 +683,7 @@
          df_end_stations("station%name"),
          df_trips("starttime"))
        .collect()
+      succeed
    } finally {
      dropTable(tableTrips)
@@ -806,6 +811,7 @@
      |-------------------
      |""".stripMargin)
    df.select(dfRight("*"), dfRight("c")).show()
+    succeed
  }
  test("select columns on join result with conflict name", JavaStoredProcExclude) {
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
index 2146cdce..28b84d8b 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameRangeSuite.scala
@@ -105,6 +105,7 @@ class DataFrameRangeSuite extends SNTestBase {
        assert(res.head.getLong(1) == expSum)
      }
    }
+    succeed
  }
  test("Session range with Max and Min") {
diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
index 82cd6e7f..89b87c87 100644
--- a/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
+++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameReaderSuite.scala
@@ -4,7 +4,7 @@ import com.snowflake.snowpark.functions._
 import com.snowflake.snowpark.types._
 import com.snowflake.snowpark._
 import net.snowflake.client.jdbc.SnowflakeSQLException
-import org.scalatest.Tag
+import org.scalatest.{Assertion, Tag}
 import scala.util.Random
@@ -226,6 +226,7 @@
          .csv(path),
        result)
    })
+    succeed
  } finally {
    runQuery(s"drop file format $formatName", session)
  }
@@ -338,6 +339,7 @@
        .parquet(s"@$tmpStageName/$ctype/")
        .collect()
    })
+    succeed
  })
  testReadFile("read parquet with no schema")(reader => {
@@ -502,7 +504,7 @@
  })
  def testReadFile(testName: String, testTags: Tag*)(
-      thunk: (() => DataFrameReader) => Unit): Unit = {
+      thunk: (() => DataFrameReader) => Assertion): Unit = {
    // test select
    test(testName + " - SELECT", testTags: _*) {
      thunk(() => session.read)
test("Union with filters") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + def check(newCol: Column, filter: Column, result: Seq[Row]): Assertion = { val df1 = session.createDataFrame(Seq((1, 1))).toDF("a", "b").withColumn("c", newCol) val df2 = df1.union(df1).withColumn("d", lit(100)).filter(filter) @@ -27,7 +28,7 @@ class DataFrameSetOperationsSuite extends TestData { } test("Union All with filters") { - def check(newCol: Column, filter: Column, result: Seq[Row]): Unit = { + def check(newCol: Column, filter: Column, result: Seq[Row]): Assertion = { val df1 = session.createDataFrame(Seq((1, 1))).toDF("a", "b").withColumn("c", newCol) val df2 = df1.unionAll(df1).withColumn("d", lit(100)).filter(filter) @@ -161,7 +162,7 @@ class DataFrameSetOperationsSuite extends TestData { df1 = Seq((1, 2, 3)).toDF("a", "b", "c") df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionByName(df2) } } @@ -182,7 +183,7 @@ class DataFrameSetOperationsSuite extends TestData { df1 = Seq((1, 2, 3)).toDF("a", "b", "c") df2 = Seq((4, 5, 6)).toDF("a", "c", "d") - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionAllByName(df2) } } @@ -196,7 +197,7 @@ class DataFrameSetOperationsSuite extends TestData { df1 = Seq((1, 2, 3)).toDF(""""a"""", "b", "c") df2 = Seq((4, 5, 6)).toDF("a", "c", "b") - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionByName(df2) } } @@ -210,7 +211,7 @@ class DataFrameSetOperationsSuite extends TestData { df1 = Seq((1, 2, 3)).toDF(""""a"""", "b", "c") df2 = Seq((4, 5, 6)).toDF("a", "c", "b") - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionAllByName(df2) } } @@ -255,6 +256,7 @@ class DataFrameSetOperationsSuite extends TestData { dates.union(widenTypedRows).collect() dates.except(widenTypedRows).collect() dates.intersect(widenTypedRows).collect() + succeed } /* @@ -271,7 +273,7 @@ class DataFrameSetOperationsSuite extends TestData { } df1 = Seq((1, 1)).toDF("c0", "c1") df2 = Seq((1, 1)).toDF(c0, c1) - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionByName(df2) } } @@ -286,7 +288,7 @@ class DataFrameSetOperationsSuite extends TestData { } df1 = Seq((1, 1)).toDF("c0", "c1") df2 = Seq((1, 1)).toDF(c0, c1) - intercept[SnowparkClientException] { + assertThrows[SnowparkClientException] { df1.unionAllByName(df2) } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala index 45aec853..67ce38f5 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataFrameSuite.scala @@ -5,7 +5,8 @@ import com.snowflake.snowpark.functions._ import com.snowflake.snowpark.internal.analyzer._ import com.snowflake.snowpark.types._ import net.snowflake.client.jdbc.SnowflakeSQLException -import org.scalatest.BeforeAndAfterEach +import org.scalatest.{Assertion, BeforeAndAfterEach} + import java.sql.{Date, Time, Timestamp} import scala.util.Random @@ -159,7 +160,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { skipIfParamNotExist = true) } - private def testCacheResult(): Unit = { + private def testCacheResult(): Assertion = { val tableName = randomName() runQuery(s"create table $tableName (num int)", session) runQuery(s"insert into $tableName values(1),(2)", session) @@ -446,8 +447,8 @@ 
trait DataFrameSuite extends TestData with BeforeAndAfterEach { test("df.stat.approxQuantile", JavaStoredProcExclude) { assert(approxNumbers.stat.approxQuantile("a", Array(0.5))(0).get == 4.5) assert( - approxNumbers.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)).deep == - Array(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0)).deep) + approxNumbers.stat.approxQuantile("a", Array(0, 0.1, 0.4, 0.6, 1)).toSeq == + Seq(Some(0.0), Some(0.9), Some(3.6), Some(5.3999999999999995), Some(9.0))) // Probability out of range error and apply on string column error. assertThrows[SnowflakeSQLException](approxNumbers.stat.approxQuantile("a", Array(-1d))) @@ -457,8 +458,8 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert(session.table(tableName).stat.approxQuantile("num", Array(0.5))(0).isEmpty) val res = double2.stat.approxQuantile(Array("a", "b"), Array(0, 0.1, 0.6)) - assert(res(0).deep == Array(Some(0.1), Some(0.12000000000000001), Some(0.22)).deep) - assert(res(1).deep == Array(Some(0.5), Some(0.52), Some(0.62)).deep) + assert(res(0).toSeq == Seq(Some(0.1), Some(0.12000000000000001), Some(0.22))) + assert(res(1).toSeq == Seq(Some(0.5), Some(0.52), Some(0.62))) // ApproxNumbers2 contains a column called T, which conflicts with tmpColumnName. // This test demos that the query still works. @@ -630,12 +631,12 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { assert(nullData1.first().get == Row(null)) assert(integer1.filter(col("a") < 0).first().isEmpty) - integer1.first(2) sameElements Seq(Row(1), Row(2)) + assert(integer1.first(2) sameElements Seq(Row(1), Row(2))) // return all elements - integer1.first(3) sameElements Seq(Row(1), Row(2), Row(3)) - integer1.first(10) sameElements Seq(Row(1), Row(2), Row(3)) - integer1.first(-10) sameElements Seq(Row(1), Row(2), Row(3)) + assert(integer1.first(3) sameElements Seq(Row(1), Row(2), Row(3))) + assert(integer1.first(10) sameElements Seq(Row(1), Row(2), Row(3))) + assert(integer1.first(-10) sameElements Seq(Row(1), Row(2), Row(3))) } test("sample() with row count") { @@ -726,7 +727,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { weights: Array[Double], index: Int, count: Long, - TotalCount: Long): Unit = { + TotalCount: Long): Assertion = { val expectedRowCount = TotalCount * weights(index) / weights.sum assert(Math.abs(expectedRowCount - count) < expectedRowCount * samplingDeviation) } @@ -1514,6 +1515,7 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { test("createDataFrame from empty Seq with schema inference") { Seq.empty[(Int, Int)].toDF("a", "b").schema.printTreeString() + succeed } test("schema inference binary type") { @@ -2182,11 +2184,11 @@ trait DataFrameSuite extends TestData with BeforeAndAfterEach { val row1 = df1.collect().head // result is non-deterministic. 
(row1.getInt(0), row1.getInt(1), row1.getInt(2), row1.getInt(3)) match { - case (1, 1, 1, 1) => - case (1, 1, 1, 2) => - case (1, 1, 2, 3) => - case (1, 2, 3, 4) => - case _ => throw new Exception("wrong result") + case (1, 1, 1, 1) => succeed + case (1, 1, 1, 2) => succeed + case (1, 1, 2, 3) => succeed + case (1, 2, 3, 4) => succeed + case _ => fail("wrong result") } } diff --git a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala index 0d19d5e8..2e1bcc14 100644 --- a/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/DataTypeSuite.scala @@ -42,6 +42,7 @@ class DataTypeSuite extends SNTestBase { } Seq(ByteType, ShortType, IntegerType, LongType).foreach(verifyIntegralType) + succeed } test("FractionalType") { @@ -55,6 +56,7 @@ class DataTypeSuite extends SNTestBase { } Seq(FloatType, DoubleType).foreach(verifyIntegralType) + succeed } test("DecimalType") { diff --git a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala index aac95147..9dba1fe8 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FileOperationSuite.scala @@ -162,6 +162,7 @@ class FileOperationSuite extends SNTestBase { assert(secondResult.length == 3) // On GCP, the files are not skipped if target file already exists secondResult.map(row => assert(row.status.equals("SKIPPED") || row.status.equals("UPLOADED"))) + succeed } test("put() negative test") { @@ -453,6 +454,7 @@ class FileOperationSuite extends SNTestBase { fileName = s"streamFile_${TestUtils.randomString(5)}.csv" testStreamRoundTrip(s"$schema.$tempStage/$fileName", s"$schema.$tempStage/$fileName.gz", true) + succeed } @@ -469,6 +471,7 @@ class FileOperationSuite extends SNTestBase { s"$randomNewSchema.$tempStage/$fileName", s"$randomNewSchema.$tempStage/$fileName.gz", true) + succeed } finally { session.sql(s"DROP SCHEMA $randomNewSchema").collect() } diff --git a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala index e473de12..cc07015b 100644 --- a/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/FunctionSuite.scala @@ -235,6 +235,7 @@ trait FunctionSuite extends TestData { val df = session.sql("select 1") df.select(random(123)).collect() df.select(random()).collect() + succeed } test("sqrt") { @@ -1907,6 +1908,7 @@ trait FunctionSuite extends TestData { approx_percentile_estimate(approx_percentile_accumulate(col("a")), 0.5)), approxNumbers.select(approx_percentile(col("a"), 0.5)), sort = false) + succeed } test("approx_percentile_combine") { @@ -2136,6 +2138,7 @@ trait FunctionSuite extends TestData { assert(row.getInt(0) >= 1) assert(row.getInt(0) <= 5) }) + succeed } test("listagg") { diff --git a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala index 7ffe6950..159d5e93 100644 --- a/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/IndependentClassSuite.scala @@ -1,6 +1,6 @@ package com.snowflake.snowpark_test -import org.scalatest.FunSuite +import org.scalatest.funsuite.{AnyFunSuite => FunSuite} import org.scalatest.exceptions.TestFailedException import 
scala.language.postfixOps diff --git a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala index 7bec5be4..796935e5 100644 --- a/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/JavaUtilsSuite.scala @@ -1,6 +1,6 @@ package com.snowflake.snowpark_test -import org.scalatest.FunSuite +import org.scalatest.funsuite.{AnyFunSuite => FunSuite} import com.snowflake.snowpark.internal.JavaUtils._ import com.snowflake.snowpark.types.Variant import com.snowflake.snowpark_java.types.{Variant => JavaVariant} diff --git a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala index a2f70f98..e91f7a9d 100644 --- a/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/LargeDataFrameSuite.scala @@ -69,6 +69,7 @@ class LargeDataFrameSuite extends TestData { (0 until (result.length - 1)).foreach(index => assert(result(index).getInt(0) < result(index + 1).getInt(0))) + succeed } test("createDataFrame for large values: basic types") { @@ -128,12 +129,13 @@ class LargeDataFrameSuite extends TestData { largeData.append( Row(1025, null, null, null, null, null, null, null, null, null, null, null, null)) - val result = session.createDataFrame(largeData, schema) + val largeDataSeq = largeData.toSeq + val result = session.createDataFrame(largeDataSeq, schema) // byte, short, int, long are converted to long // float and double are converted to double result.schema.printTreeString() assert(getSchemaString(result.schema) == schemaString) - checkAnswer(result.sort(col("id")), largeData, false) + checkAnswer(result.sort(col("id")), largeDataSeq, sort = false) } test("createDataFrame for large values: time") { @@ -146,7 +148,7 @@ class LargeDataFrameSuite extends TestData { } largeData.append(Row(rowCount, null)) - val df = session.createDataFrame(largeData, schema) + val df = session.createDataFrame(largeData.toSeq, schema) assert( getSchemaString(df.schema) == """root @@ -160,7 +162,7 @@ class LargeDataFrameSuite extends TestData { expected.append(Row(i.toLong, snowflakeTime)) } expected.append(Row(rowCount, null)) - checkAnswer(df.sort(col("id")), expected, sort = false) + checkAnswer(df.sort(col("id")), expected.toSeq, sort = false) } // In the result, Array, Map and Geography are String data @@ -188,7 +190,7 @@ class LargeDataFrameSuite extends TestData { } largeData.append(Row(rowCount, null, null, null, null, null)) - val df = session.createDataFrame(largeData, schema) + val df = session.createDataFrame(largeData.toSeq, schema) assert( getSchemaString(df.schema) == """root @@ -224,7 +226,7 @@ class LargeDataFrameSuite extends TestData { |}""".stripMargin))) } expected.append(Row(rowCount, null, null, null, null, null)) - checkAnswer(df.sort(col("id")), expected, sort = false) + checkAnswer(df.sort(col("id")), expected.toSeq, sort = false) } test("createDataFrame for large values: variant in array and map") { @@ -240,13 +242,13 @@ class LargeDataFrameSuite extends TestData { Row(i.toLong, Array(new Variant(1), new Variant("\"'")), Map("a" -> new Variant("\"'")))) } largeData.append(Row(rowCount, null, null)) - val df = session.createDataFrame(largeData, schema) + val df = session.createDataFrame(largeData.toSeq, schema) val expected = new ArrayBuffer[Row]() for (i <- 0 until rowCount) { expected.append(Row(i.toLong, "[\n 1,\n 
\"\\\"'\"\n]", "{\n \"a\": \"\\\"'\"\n}")) } expected.append(Row(rowCount, null, null)) - checkAnswer(df.sort(col("id")), expected, sort = false) + checkAnswer(df.sort(col("id")), expected.toSeq, sort = false) } test("createDataFrame for large values: geography in array and map") { @@ -269,7 +271,7 @@ class LargeDataFrameSuite extends TestData { "b" -> Geography.fromGeoJSON("{\"type\":\"Point\",\"coordinates\":[300,100]}")))) } largeData.append(Row(rowCount, null, null)) - val df = session.createDataFrame(largeData, schema) + val df = session.createDataFrame(largeData.toSeq, schema) val expected = new ArrayBuffer[Row]() for (i <- 0 until rowCount) { expected.append( @@ -281,7 +283,7 @@ class LargeDataFrameSuite extends TestData { " 300,\n 100\n ],\n \"type\": \"Point\"\n }\n}")) } expected.append(Row(rowCount, null, null)) - checkAnswer(df.sort(col("id")), expected, sort = false) + checkAnswer(df.sort(col("id")), expected.toSeq, sort = false) } test("test large ResultSet with multiple chunks") { diff --git a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala index dcb35cdb..edb02a85 100644 --- a/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/PermanentUDFSuite.scala @@ -111,6 +111,7 @@ class PermanentUDFSuite extends TestData { checkAnswer(df.select(callUDF(permFuncName, df("a"))), Seq(Row(2), Row(3))) runQuery(s"drop function $tempFuncName(INT)", session) runQuery(s"drop function $permFuncName(INT)", session) + succeed } finally { runQuery(s"drop function if exists $tempFuncName(INT)", session) runQuery(s"drop function if exists $permFuncName(INT)", session) @@ -153,6 +154,7 @@ class PermanentUDFSuite extends TestData { } assert(ex.getMessage.matches(".*The object name .* is invalid.")) } + succeed } test("Clean up uploaded jar files if UDF registration fails", JavaStoredProcExclude) { diff --git a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala index 93dbc99c..5a85c412 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ResultSchemaSuite.scala @@ -58,6 +58,7 @@ class ResultSchemaSuite extends TestData { verifySchema( "alter session set ABORT_DETACHED_QUERY=false", session.sql("alter session set ABORT_DETACHED_QUERY=false").schema) + succeed } test("list, remove file") { @@ -75,10 +76,12 @@ class ResultSchemaSuite extends TestData { verifySchema( s"remove @$stageName/$testFileCsv", session.sql(s"remove @$stageName/$testFile2Csv").schema) + succeed } test("select") { verifySchema(s"select * from $tableName", session.sql(s"select * from $tableName").schema) + succeed } test("analyze schema") { @@ -90,6 +93,7 @@ class ResultSchemaSuite extends TestData { verifySchema( s"""select string, "int", array, "date" from $fullTypesTable where \"int\" > 0""", df2.schema) + succeed } // ignore it for now since we are modifying the analyzer system. 
@@ -146,6 +150,7 @@ class ResultSchemaSuite extends TestData { assert(tsSchema(index).dataType == typeMap(index).tsType) }) statement.close() + succeed } test("verify Geography schema type") { @@ -173,6 +178,7 @@ class ResultSchemaSuite extends TestData { assert(resultMeta.getColumnType(1) == Types.BINARY) assert(tsSchema.head.dataType == GeographyType) statement.close() + succeed } finally { // Assign output format to the default value runQuery(s"alter session set GEOGRAPHY_OUTPUT_FORMAT='GeoJSON'", session) @@ -204,6 +210,7 @@ class ResultSchemaSuite extends TestData { assert(resultMeta.getColumnType(1) == Types.BINARY) assert(tsSchema.head.dataType == GeometryType) statement.close() + succeed } finally { // Assign output format to the default value runQuery(s"alter session set GEOMETRY_OUTPUT_FORMAT='GeoJSON'", session) @@ -217,5 +224,6 @@ class ResultSchemaSuite extends TestData { assert(resultMeta.getColumnType(1) == Types.TIME) assert(tsSchema.head.dataType == TimeType) statement.close() + succeed } } diff --git a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala index 35e8c572..757fd728 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ScalaVariantSuite.scala @@ -1,7 +1,7 @@ package com.snowflake.snowpark_test import com.snowflake.snowpark.types.{Geography, Variant} -import org.scalatest.FunSuite +import org.scalatest.funsuite.{AnyFunSuite => FunSuite} import java.io.UncheckedIOException import java.sql.{Date, Time, Timestamp} @@ -64,7 +64,7 @@ class ScalaVariantSuite extends FunSuite { assert(vFloat.asLong() == 1L) assert(vFloat.asInt() == 1) assert(vFloat.asShort() == 1.toShort) - assert((vFloat.asBigDecimal() - BigDecimal("1.1")).abs.doubleValue() < 0.0000001) + assert((vFloat.asBigDecimal() - BigDecimal("1.1")).abs.doubleValue < 0.0000001) assert(vFloat.asBigInt() == BigInt("1")) assert(vFloat.asTimestamp().equals(new Timestamp(1L))) assert(vFloat.asString().equals("1.1")) diff --git a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala index 9b82ba7b..9d8a990b 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SessionSuite.scala @@ -47,6 +47,7 @@ class SessionSuite extends SNTestBase { t1.run() t2.run() + succeed } test("Test for get or create session") { diff --git a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala index a0ae3be0..75442807 100644 --- a/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/SqlSuite.scala @@ -76,7 +76,9 @@ trait SqlSuite extends SNTestBase { // add spaces to the query val putQuery = - s" put ${TestUtils.escapePath(path.toString).replace("file:/", "file:///")} @$stageName " + s" put ${TestUtils + .escapePath(path.toString) + .replace("file:/", "file:" + "/" + "/" + "/")} @$stageName " val put = session.sql(putQuery) put.schema.printTreeString() // should upload nothing @@ -99,6 +101,7 @@ trait SqlSuite extends SNTestBase { // TODO: Below assertion failed on GCP because JDBC bug SNOW-493080 // Disable this check temporally // assert(new File(s"$outputPath/$fileName.gz").exists()) + succeed } finally { // remove tmp file new Directory(new File(outputPath)).deleteRecursively() diff --git 
a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala index 8f214271..7cf8881e 100644 --- a/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/StoredProcedureSuite.scala @@ -145,6 +145,7 @@ class StoredProcedureSuite extends SNTestBase { test("decimal input") { val sp = session.sproc.registerTemporary((_: Session, num: java.math.BigDecimal) => num) session.storedProcedure(sp, java.math.BigDecimal.valueOf(123)).show() + succeed } test("binary type") { @@ -2286,6 +2287,7 @@ println(s""" assert(newSession.sql(s"show procedures like '$name1'").collect().isEmpty) assert(newSession.sql(s"show procedures like '$name2'").collect().isEmpty) newSession.close() + succeed } // temporary disabled, waiting for server side JDBC upgrade diff --git a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala index 2d5357fe..0cc490a9 100644 --- a/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/TableSuite.scala @@ -181,6 +181,7 @@ class TableSuite extends TestData { checkAnswer(session.table(name2), Seq(Row(1), Row(2), Row(3))) dropTable(name2) + succeed } test("read from different schema") { @@ -308,6 +309,7 @@ class TableSuite extends TestData { df.write.mode(SaveMode.Overwrite).saveAsTable(list) checkAnswer(session.table(tableName), Seq(Row(1), Row(2), Row(3))) dropTable(tableName) + succeed } finally { dropTable(tableName) diff --git a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala index 61ed76a0..a106e567 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UDFSuite.scala @@ -89,15 +89,15 @@ trait UDFSuite extends TestData { } assert(ex.getMessage.contains("mutable.Map[String,String]")) - intercept[UnsupportedOperationException] { + assertThrows[UnsupportedOperationException] { udf((x: java.lang.Integer) => x + x) } - intercept[UnsupportedOperationException] { + assertThrows[UnsupportedOperationException] { udf((x: Option[String]) => x.map(a => s"$a")) } - intercept[UnsupportedOperationException] { + assertThrows[UnsupportedOperationException] { udf((x: mutable.Map[String, String]) => x.keys) } } @@ -143,6 +143,7 @@ trait UDFSuite extends TestData { assert(result.size == 2) Seq(1, 2, 3).foreach(i => assert(result(0).getString(0).contains(s"convertToMap$i"))) Seq(4, 5, 6).foreach(i => assert(result(1).getString(0).contains(s"convertToMap$i"))) + succeed } test("UDF with multiple args of type map, array etc") { @@ -2816,6 +2817,7 @@ class AlwaysCleanUDFSuite extends UDFSuite with AlwaysCleanSession { val myDf = session.sql("select 'Raymond' NAME") val readFileUdf = udf(TestClassWithoutFieldAccess.run) myDf.withColumn("CONCAT", readFileUdf(col("NAME"))).show() + succeed } } @@ -2824,7 +2826,7 @@ class NeverCleanUDFSuite extends UDFSuite with NeverCleanSession { test("Test without closure cleaner") { val myDf = session.sql("select 'Raymond' NAME") // Without closure cleaner, this test will throw error - intercept[NotSerializableException] { + assertThrows[NotSerializableException] { val readFileUdf = udf(TestClassWithoutFieldAccess.run) myDf.withColumn("CONCAT", readFileUdf(col("NAME"))).show() } diff --git a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala 
b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala index d6845bde..fec0da6d 100644 --- a/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/UDTFSuite.scala @@ -12,7 +12,7 @@ import com.snowflake.snowpark.internal._ import com.snowflake.snowpark.types._ import com.snowflake.snowpark.udtf._ -import scala.collection.{Seq, mutable} +import scala.collection.mutable @UDFTest class UDTFSuite extends TestData { @@ -292,6 +292,7 @@ class UDTFSuite extends TestData { session.udtf.registerTemporary(funcName, new MyDuplicateRegisterUDTF1()) session.udtf.registerTemporary(funcName2, new MyDuplicateRegisterUDTF1()) + succeed } finally { runQuery(s"drop function if exists $funcName(STRING)", session) } @@ -421,6 +422,7 @@ class UDTFSuite extends TestData { |""".stripMargin) checkAnswer(df, Seq(Row(30), Row(31), Row(32), Row(33), Row(34)), false) } + succeed } finally { runQuery(s"drop function if exists $funcName(NUMBER, NUMBER)", session) } @@ -2138,6 +2140,7 @@ class UDTFSuite extends TestData { Seq(Row("a", null, "Map(b -> 2, c -> 1)"), Row("d", null, "Map(e -> 1)"))) df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq(df("b"))).show() df.join(tf, Map("arg1" -> df("b")), Seq.empty, Seq.empty).show() + succeed } test("single partition") { diff --git a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala index 79acd450..91d43641 100644 --- a/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/ViewSuite.scala @@ -108,6 +108,7 @@ class ViewSuite extends TestData { checkAnswer(session.table(viewName), Seq(Row(1), Row(2), Row(3))) dropView(viewName) + succeed } finally { dropView(viewName) } diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala index 89e8cbd4..f5e4dce3 100644 --- a/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/WindowFramesSuite.scala @@ -114,14 +114,14 @@ class WindowFramesSuite extends TestData { .over(window.rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing))), Seq(Row(1, 1))) - intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect) + assertThrows[SnowflakeSQLException]( + df.select(min($"key").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect()) - intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect) + assertThrows[SnowflakeSQLException]( + df.select(min($"key").over(window.rangeBetween(-1, Window.unboundedFollowing))).collect()) - intercept[SnowflakeSQLException]( - df.select(min($"key").over(window.rangeBetween(-1, 1))).collect) + assertThrows[SnowflakeSQLException]( + df.select(min($"key").over(window.rangeBetween(-1, 1))).collect()) } test("SN - range between should accept numeric values only when bounded") { @@ -136,13 +136,13 @@ class WindowFramesSuite extends TestData { Row("non_numeric", "non_numeric") :: Nil) // TODO: Add another test with eager mode enabled - intercept[SnowflakeSQLException]( + assertThrows[SnowflakeSQLException]( df.select(min($"value").over(window.rangeBetween(Window.unboundedPreceding, 1))).collect()) - intercept[SnowflakeSQLException]( + assertThrows[SnowflakeSQLException]( df.select(min($"value").over(window.rangeBetween(-1, 
Window.unboundedFollowing))).collect()) - intercept[SnowflakeSQLException]( + assertThrows[SnowflakeSQLException]( df.select(min($"value").over(window.rangeBetween(-1, 1))).collect()) } @@ -186,5 +186,6 @@ class WindowFramesSuite extends TestData { assert(values_1_B.contains(row.getInt(1)) && values_2_B.contains(row.getInt(2))) } } + succeed } } diff --git a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala index 00100566..7506c34f 100644 --- a/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala +++ b/src/test/scala/com/snowflake/snowpark_test/WindowSpecSuite.scala @@ -3,7 +3,7 @@ package com.snowflake.snowpark_test import com.snowflake.snowpark.functions._ import com.snowflake.snowpark.{DataFrame, Row, TestData, Window} import net.snowflake.client.jdbc.SnowflakeSQLException -import org.scalatest.Matchers.the +import org.scalatest.matchers.should.Matchers.the import scala.reflect.ClassTag @@ -97,6 +97,7 @@ class WindowSpecSuite extends TestData { |GROUP BY a |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1 |""".stripMargin)) + succeed } test("reuse window partitionBy") { @@ -171,7 +172,7 @@ class WindowSpecSuite extends TestData { test("SN - window function should fail if order by clause is not specified") { val df = Seq((1, "1"), (2, "2"), (1, "2"), (2, "2")).toDF("key", "value") - val e = intercept[SnowflakeSQLException]( + assertThrows[SnowflakeSQLException]( // Here we missed .orderBy("key")! df.select(row_number().over(Window.partitionBy($"value"))).collect()) } @@ -329,8 +330,7 @@ class WindowSpecSuite extends TestData { test("SN - aggregation function on invalid column") { val df = Seq((1, "1")).toDF("key", "value") - val e = - intercept[SnowflakeSQLException](df.select($"key", count($"invalid").over()).collect()) + assertThrows[SnowflakeSQLException](df.select($"key", count($"invalid").over()).collect()) } test("SN - statistical functions") { diff --git a/src/test/scala/com/snowflake/test/QueryTagSuite.scala b/src/test/scala/com/snowflake/test/QueryTagSuite.scala index 8ad7f44c..760f7fa9 100644 --- a/src/test/scala/com/snowflake/test/QueryTagSuite.scala +++ b/src/test/scala/com/snowflake/test/QueryTagSuite.scala @@ -21,5 +21,6 @@ class QueryTagSuite extends SNTestBase { res.next() assert(res.getString("value").equals(someTag)) statement.close() + succeed } }
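The remaining mechanical change throughout these suites is ScalaTest 3.2's package reorganization: org.scalatest.FunSuite became org.scalatest.funsuite.AnyFunSuite, and org.scalatest.Matchers became org.scalatest.matchers.should.Matchers. The PR aliases the new name back to FunSuite so class declarations stay untouched. A minimal sketch under ScalaTest 3.2.x (the class and test names are illustrative):

import org.scalatest.funsuite.{AnyFunSuite => FunSuite} // was org.scalatest.FunSuite
import org.scalatest.matchers.should.Matchers           // was org.scalatest.Matchers

// Aliasing the import keeps every existing "extends FunSuite" declaration
// compiling without edits to the suite bodies themselves.
class PackageMoveExample extends FunSuite with Matchers {
  test("should-style matchers read the same after the move") {
    (1 + 1) should be(2)
  }
}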