Scala 213 #90

Open · wants to merge 4 commits into base: main
pom.xml (10 changes: 5 additions & 5 deletions)

@@ -34,9 +34,9 @@
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
<encoding>UTF-8</encoding>
-<scala.version>2.12.18</scala.version>
-<scala.compat.version>2.12</scala.compat.version>
-<spec2.version>4.2.0</spec2.version>
+<scala.version>2.13.13</scala.version>
+<scala.compat.version>2.13</scala.compat.version>
+<spec2.version>4.20.5</spec2.version>
<snowflake.jdbc.version>3.16.0</snowflake.jdbc.version>
<version.scala.binary>${scala.compat.version}</version.scala.binary>
<doctitle>Snowpark ${project.version}</doctitle>
@@ -134,7 +134,7 @@
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.compat.version}</artifactId>
-<version>3.0.5</version>
+<version>3.2.18</version>
<scope>test</scope>
</dependency>
<dependency>
@@ -496,7 +496,7 @@
<profile>
<id>test-coverage</id>
<properties>
-<scala.version>2.12.15</scala.version>
+<scala.version>2.13.13</scala.version>
</properties>
<build>
<plugins>
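Note on the pom.xml bump: the scalatest artifactId is cross-versioned through the scala.compat.version property, so moving that property to 2.13 repoints the build at the scalatest_2.13 binaries. The 3.0.5 -> 3.2.18 jump also removes the old org.scalatest.FunSuite style trait, which is why the test sources further down switch to org.scalatest.funsuite.AnyFunSuite.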
src/main/scala/com/snowflake/snowpark/DataFrame.scala (19 changes: 11 additions & 8 deletions)

@@ -1797,7 +1797,10 @@ class DataFrame private[snowpark] (
* @param firstArg The first argument to pass to the specified table function.
* @param remaining A list of any additional arguments for the specified table function.
*/
-def join(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+def join(
+    func: com.snowflake.snowpark.TableFunction,
+    firstArg: Column,
+    remaining: Column*): DataFrame =
join(func, firstArg +: remaining)

/**
@@ -1824,7 +1827,7 @@
* object or an object that you create from the [[TableFunction]] class.
* @param args A list of arguments to pass to the specified table function.
*/
-def join(func: TableFunction, args: Seq[Column]): DataFrame =
+def join(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame =
joinTableFunction(func.call(args: _*), None)

/**
@@ -1854,7 +1857,7 @@
* @param orderBy A list of columns ordered by.
*/
def join(
-    func: TableFunction,
+    func: com.snowflake.snowpark.TableFunction,
args: Seq[Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame =
@@ -1892,7 +1895,7 @@
* Some functions, like `flatten`, have named parameters.
* Use this map to specify the parameter names and their corresponding values.
*/
-def join(func: TableFunction, args: Map[String, Column]): DataFrame =
+def join(func: com.snowflake.snowpark.TableFunction, args: Map[String, Column]): DataFrame =
joinTableFunction(func.call(args), None)

/**
@@ -1929,7 +1932,7 @@
* @param orderBy A list of columns ordered by.
*/
def join(
-    func: TableFunction,
+    func: com.snowflake.snowpark.TableFunction,
args: Map[String, Column],
partitionBy: Seq[Column],
orderBy: Seq[Column]): DataFrame =
@@ -2366,7 +2369,7 @@
}
}
lines.append(value.substring(startIndex))
-lines
+lines.toSeq
}

def convertValueToString(value: Any): String =
@@ -2501,7 +2504,7 @@
* and view name.
*/
def createOrReplaceView(multipartIdentifier: java.util.List[String]): Unit =
-createOrReplaceView(multipartIdentifier.asScala)
+createOrReplaceView(multipartIdentifier.asScala.toSeq)

/**
* Creates a temporary view that returns the same results as this DataFrame.
@@ -2564,7 +2567,7 @@
* view name.
*/
def createOrReplaceTempView(multipartIdentifier: java.util.List[String]): Unit =
-createOrReplaceTempView(multipartIdentifier.asScala)
+createOrReplaceTempView(multipartIdentifier.asScala.toSeq)

private def doCreateOrReplaceView(viewName: String, viewType: ViewType): Unit = {
session.conn.telemetry.reportActionCreateOrReplaceView()
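Note: the .toSeq additions in this file (and in JavaUtils, UDXRegistrationHandler, SnowflakePlan, and Variant below) all trace back to the same 2.13 change: scala.Seq is now an alias for immutable.Seq, so mutable buffers, builders, and asScala wrappers no longer conform to Seq on their own. A minimal sketch of the mismatch (the values here are illustrative, not from this PR):

    import scala.collection.JavaConverters._

    val jList: java.util.List[String] = java.util.Arrays.asList("a", "b")
    val buf: scala.collection.mutable.Buffer[String] = jList.asScala

    // 2.12: `val s: Seq[String] = buf` compiled, since Seq meant collection.Seq.
    // 2.13: Seq means immutable.Seq, so an explicit copy is required:
    val seq: Seq[String] = buf.toSeq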
src/main/scala/com/snowflake/snowpark/Row.scala (2 changes: 1 addition & 1 deletion)

@@ -350,7 +350,7 @@ class Row protected (values: Array[Any]) extends Serializable {
* @since 1.13.0
* @group getter
*/
-def getSeq[T](index: Int): Seq[T] = {
+def getSeq[T: ClassTag](index: Int): Seq[T] = {
val result = getAs[Array[_]](index)
result.map {
case x: T => x
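Note: getSeq pattern-matches on the element type (case x: T), and without a ClassTag that test is erased and always succeeds; 2.13 is stricter about flagging it. With the T: ClassTag bound, the compiler turns the pattern into a real runtime check. The same idiom in isolation (firstOfType is a hypothetical helper, not part of the API):

    import scala.reflect.ClassTag

    // With T: ClassTag in scope, `case x: T` compiles to an actual
    // runtime type test instead of an erased, always-true match.
    def firstOfType[T: ClassTag](values: Seq[Any]): Option[T] =
      values.collectFirst { case x: T => x }

    firstOfType[Int](Seq("a", 1, 2.0)) // Some(1)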
src/main/scala/com/snowflake/snowpark/Session.scala (17 changes: 11 additions & 6 deletions)

@@ -444,7 +444,7 @@ class Session private (private[snowpark] val conn: ServerConnection) extends Logging {
* @since 0.2.0
*/
def table(multipartIdentifier: java.util.List[String]): Updatable =
-table(multipartIdentifier.asScala)
+table(multipartIdentifier.asScala.toSeq)

/**
* Returns an Updatable that points to the specified table.
@@ -497,7 +497,10 @@
* @param firstArg the first function argument of the given table function.
* @param remaining all remaining function arguments.
*/
-def tableFunction(func: TableFunction, firstArg: Column, remaining: Column*): DataFrame =
+def tableFunction(
+    func: com.snowflake.snowpark.TableFunction,
+    firstArg: Column,
+    remaining: Column*): DataFrame =
tableFunction(func, firstArg +: remaining)

/**
@@ -525,7 +528,7 @@
* referred from the built-in list from tableFunctions.
* @param args function arguments of the given table function.
*/
-def tableFunction(func: TableFunction, args: Seq[Column]): DataFrame = {
+def tableFunction(func: com.snowflake.snowpark.TableFunction, args: Seq[Column]): DataFrame = {
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
@@ -569,7 +572,9 @@
* Some functions, like flatten, have named parameters.
* use this map to assign values to the corresponding parameters.
*/
-def tableFunction(func: TableFunction, args: Map[String, Column]): DataFrame = {
+def tableFunction(
+    func: com.snowflake.snowpark.TableFunction,
+    args: Map[String, Column]): DataFrame = {
// Use df.join to apply function result if args contains a DF column
val sourceDFs = args.values.flatMap(_.expr.sourceDFs)
if (sourceDFs.isEmpty) {
@@ -615,9 +620,9 @@
def tableFunction(func: Column): DataFrame = {
func.expr match {
case TFunction(funcName, args) =>
-      tableFunction(TableFunction(funcName), args.map(Column(_)))
+      tableFunction(com.snowflake.snowpark.TableFunction(funcName), args.map(Column(_)))
case NamedArgumentsTableFunction(funcName, argMap) =>
-      tableFunction(TableFunction(funcName), argMap.map {
+      tableFunction(com.snowflake.snowpark.TableFunction(funcName), argMap.map {
case (key, value) => key -> Column(value)
})
case _ => throw ErrorMessage.MISC_INVALID_TABLE_FUNCTION_INPUT()
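Note: the fully qualified com.snowflake.snowpark.TableFunction in these signatures is presumably there to pin down the Scala class and rule out ambiguity with the identically named type in the Java API (com.snowflake.snowpark_java) when both source trees compile together under 2.13; the overloads behave as before.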
src/main/scala/com/snowflake/snowpark/internal/JavaUtils.scala (30 changes: 19 additions & 11 deletions)

@@ -319,7 +319,7 @@ object JavaUtils {
def stringArrayToStringSeq(arr: Array[String]): Seq[String] = arr

def objectListToAnySeq(input: java.util.List[java.util.List[Object]]): Seq[Seq[Any]] =
-    input.asScala.map(list => list.asScala)
+    input.asScala.map(list => list.asScala.toSeq).toSeq

def registerUDF(
udfRegistration: UDFRegistration,
@@ -360,18 +360,26 @@
map.asScala.toMap

def javaMapToScalaWithVariantConversion(map: java.util.Map[_, _]): Map[Any, Any] =
-    map.asScala.map {
-      case (key, value: com.snowflake.snowpark_java.types.Variant) =>
-        key -> InternalUtils.toScalaVariant(value)
-      case (key, value) => key -> value
-    }.toMap
+    map.asScala
+      .map((e: (Any, Any)) => {
+        (e) match {
+          case (key, value: com.snowflake.snowpark_java.types.Variant) =>
+            key -> InternalUtils.toScalaVariant(value)
+          case (key, value) => key -> value
+        }
+      })
+      .toMap

def scalaMapToJavaWithVariantConversion(map: Map[_, _]): java.util.Map[Object, Object] =
-    map.map {
-      case (key, value: com.snowflake.snowpark.types.Variant) =>
-        key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
-      case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
-    }.asJava
+    map
+      .map((e: (Any, Any)) => {
+        (e) match {
+          case (key, value: com.snowflake.snowpark.types.Variant) =>
+            key.asInstanceOf[Object] -> InternalUtils.createVariant(value)
+          case (key, value) => key.asInstanceOf[Object] -> value.asInstanceOf[Object]
+        }
+      })
+      .asJava

def serialize(obj: Any): Array[Byte] = {
val bos = new ByteArrayOutputStream()
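Note: rewriting the { case ... } literals as explicit (e: (Any, Any)) => e match { ... } functions preserves behavior; the spelled-out parameter type likely exists to help 2.13's inference, which is fussier about pattern-matching lambdas over the existential key/value types of a Map[_, _]. The two forms are otherwise interchangeable:

    // Two equivalent Function1 values over a tuple; the second names
    // and types the parameter instead of relying on inference.
    val f1: ((Any, Any)) => (Any, Any) = { case (k, v) => k -> v }
    val f2: ((Any, Any)) => (Any, Any) = (e: (Any, Any)) => e match { case (k, v) => k -> v }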
ServerConnection.scala

@@ -52,6 +52,7 @@ import net.snowflake.client.jdbc.internal.apache.arrow.vector.util.{

import java.util
import scala.collection.mutable
+import scala.reflect.classTag
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConverters._

@@ -317,7 +318,7 @@

private[snowpark] def resultSetToRows(statement: Statement): Array[Row] = withValidConnection {
val iterator = resultSetToIterator(statement)._1
-    val buff = mutable.ArrayBuilder.make[Row]()
+    val buff = mutable.ArrayBuilder.make[Row](classTag[Row])
while (iterator.hasNext) {
buff += iterator.next()
}
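Note: ArrayBuilder.make had an empty explicit parameter list in 2.12 (make[T: ClassTag]()) but is parameterless in 2.13 (make[T: ClassTag]), so make[Row]() no longer compiles; passing the ClassTag explicitly fills what is now the only (implicit) list. A sketch under that reading:

    import scala.collection.mutable
    import scala.reflect.classTag

    // 2.12: mutable.ArrayBuilder.make[String]()
    // 2.13: the () is gone; supply the ClassTag explicitly, or write
    // make[String] and let the compiler provide it implicitly.
    val buff = mutable.ArrayBuilder.make[String](classTag[String])
    buff += "a"
    buff += "b"
    val arr: Array[String] = buff.result()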
UDXRegistrationHandler.scala

@@ -385,7 +385,7 @@ class UDXRegistrationHandler(session: Session) extends Logging {
if (actionID <= session.getLastCanceledID) {
throw ErrorMessage.MISC_QUERY_IS_CANCELLED()
}
-      allUrls
+      allUrls.toSeq
}
(allImports, targetJarStageLocation)
}
src/main/scala/com/snowflake/snowpark/internal/Utils.scala (4 changes: 2 additions & 2 deletions)

@@ -27,9 +27,9 @@ object Utils extends Logging {
// Define the compat scala version instead of reading from property file
// because it fails to read the property file in some environment such as
// VSCode worksheet.
val ScalaCompatVersion: String = "2.12"
val ScalaCompatVersion: String = "2.13"
// Minimum minor version. We require version to be greater than 2.12.9
val ScalaMinimumMinorVersion: String = "2.12.9"
val ScalaMinimumMinorVersion: String = "2.13.9"

// Minimum GS version for us to identify as Snowpark client
val MinimumGSVersionForSnowparkClientType: String = "5.20.0"
SnowflakePlan.scala

@@ -74,9 +74,9 @@ class SnowflakePlan(
}
val supportAsyncMode = subqueryPlans.forall(_.supportAsyncMode)
SnowflakePlan(
-      preQueries :+ queries.last,
+      (preQueries :+ queries.last).toSeq,
newSchemaQuery,
-      newPostActions,
+      newPostActions.toSeq,
session,
sourcePlan,
supportAsyncMode)
src/main/scala/com/snowflake/snowpark/types/Variant.scala (9 changes: 7 additions & 2 deletions)

@@ -6,6 +6,7 @@ import com.fasterxml.jackson.databind.node.{ArrayNode, JsonNodeFactory, ObjectNode}
import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger}
import java.sql.{Date, Time, Timestamp}
import java.util.{List => JavaList, Map => JavaMap}
+import scala.jdk.FunctionConverters._
import scala.collection.JavaConverters._
import Variant._
import org.apache.commons.codec.binary.{Base64, Hex}
@@ -225,7 +226,7 @@
*
* @since 0.2.0
*/
-  def this(list: JavaList[Object]) = this(list.asScala)
+  def this(list: JavaList[Object]) = this(list.asScala.toSeq)

/**
* Creates a Variant from array
@@ -244,7 +245,11 @@
{
def mapToNode(map: JavaMap[Object, Object]): ObjectNode = {
val result = MAPPER.createObjectNode()
-    map.keySet().forEach(key => result.set(key.toString, objectToJsonNode(map.get(key))))
+    val consumer =
+      (key: Object) => result.set(key.toString, objectToJsonNode(map.get(key)))
+    map
+      .keySet()
+      .forEach(consumer.asJavaConsumer)
result
}
obj match {
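Note: scala.jdk.FunctionConverters (new in 2.13) converts the Scala lambda into an explicit java.util.function.Consumer before the forEach call, rather than relying on SAM conversion at the call site, presumably to keep overload resolution unambiguous under 2.13. The same pattern in isolation:

    import scala.jdk.FunctionConverters._

    val names = new java.util.ArrayList[String]()
    names.add("a")
    names.add("b")

    // asJavaConsumer turns a String => Unit into java.util.function.Consumer[String]
    val printer = (s: String) => println(s)
    names.forEach(printer.asJavaConsumer)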
@@ -1,7 +1,7 @@
package com.snowflake.code_verification

import com.snowflake.snowpark.{CodeVerification, DataFrame}
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

// verify API Java and Scala API contain same functions
@CodeVerification
@@ -1,7 +1,7 @@
package com.snowflake.code_verification

import com.snowflake.snowpark.CodeVerification
-import org.scalatest.FunSuite
+import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

import scala.collection.mutable

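Note: ScalaTest 3.1 moved the classic styles to per-style packages (org.scalatest.funsuite.AnyFunSuite) and the old org.scalatest.FunSuite alias is gone by 3.2.18, hence this rename in both suites; aliasing AnyFunSuite back to FunSuite keeps the rest of each file untouched. For example:

    import org.scalatest.funsuite.{AnyFunSuite => FunSuite}

    // The suite body compiles exactly as it did against ScalaTest 3.0.
    class ExampleSuite extends FunSuite {
      test("arithmetic") {
        assert(1 + 1 == 2)
      }
    }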
src/test/scala/com/snowflake/perf/PerfBase.scala (6 changes: 5 additions & 1 deletion)

@@ -77,14 +77,18 @@ trait PerfBase extends SNTestBase {
test(testName) {
try {
writeResult(testName, timer(func))
+        succeed
} catch {
case ex: Exception =>
writeResult(testName, -1.0) // -1.0 if failed
throw ex
}
}
} else {
-      ignore(testName)(func)
+      ignore(testName) {
+        func
+        succeed
+      }
}
}
}
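Note: the added succeed calls make each branch of the test body end in an org.scalatest.Assertion rather than Unit, which keeps the timed and ignored branches type-consistent under the newer ScalaTest (and sidesteps discarded-value warnings under 2.13's stricter linting). The shape of the pattern, as a sketch:

    import org.scalatest.funsuite.AnyFunSuite

    class TimedSuite extends AnyFunSuite {
      test("timed step") {
        try {
          // ... timed work, recorded elsewhere ...
          succeed // end the branch with an Assertion instead of Unit
        } catch {
          case ex: Exception =>
            // record the failure, then rethrow so the test still fails
            throw ex
        }
      }
    }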