Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
Signed-off-by: Kenrick Yap <[email protected]>
  • Loading branch information
kenrickyap committed Dec 18, 2024
1 parent 9a7dace commit 543740d
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
}

protected def createGeoIpTestTable(testTable: String): Unit = {
sql(
s"""
sql(s"""
| CREATE TABLE $testTable
| (
| ip STRING,
Expand All @@ -782,8 +781,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| USING $tableType $tableOptions
|""".stripMargin)

sql(
s"""
sql(s"""
| INSERT INTO $testTable
| VALUES ('66.249.157.90', true),
| ('2a09:bac2:19f8:2ac3::', true),
Expand All @@ -793,8 +791,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
}

protected def createGeoIpTable(): Unit = {
sql(
s"""
sql(s"""
| CREATE TABLE geoip
| (
| cidr STRING,
Expand All @@ -813,8 +810,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| USING $tableType $tableOptions
|""".stripMargin)

sql(
s"""
sql(s"""
| INSERT INTO geoip
| VALUES (
| '66.249.157.0/24',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLGeoipITSuite
extends QueryTest
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {
Expand All @@ -34,8 +34,7 @@ class FlintSparkPPLGeoipITSuite
}

test("test geoip with no parameters") {
val frame = sql(
s"""
val frame = sql(s"""
| source = $testTable| where isValid = true | eval a = geoip(ip) | fields ip, a
| """.stripMargin)

Expand All @@ -44,37 +43,52 @@ class FlintSparkPPLGeoipITSuite

// Define the expected results
val expectedResults: Array[Row] = Array(
Row("66.249.157.90", Row("JM", "Jamaica", "North America", "14", "Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686,-76.8827")),
Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396,-63.1355"))
)
Row(
"66.249.157.90",
Row(
"JM",
"Jamaica",
"North America",
"14",
"Saint Catherine Parish",
"Portmore",
"America/Jamaica",
"17.9686,-76.8827")),
Row(
"2a09:bac2:19f8:2ac3::",
Row(
"CA",
"Canada",
"North America",
"PE",
"Prince Edward Island",
"Charlottetown",
"America/Halifax",
"46.2396,-63.1355")))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}

// Verifies that geoip(ip, country_name) yields a single scalar column holding the
// country name for each row of the test table where isValid = true.
test("test geoip with one parameters") {
  val frame = sql(s"""
       | source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ip, a
       | """.stripMargin)

  // Retrieve the results
  val results: Array[Row] = frame.collect()
  // Define the expected results: (ip, country_name) pairs
  val expectedResults: Array[Row] =
    Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada"))

  // Compare the results; both sides are sorted by the ip column (index 0) so the
  // assertion does not depend on the order in which rows are returned.
  implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
  assert(results.sorted.sameElements(expectedResults.sorted))
}

test("test geoip with multiple parameters") {
val frame = sql(
s"""
val frame = sql(s"""
| source = $testTable| where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a
| """.stripMargin)

Expand All @@ -83,8 +97,7 @@ class FlintSparkPPLGeoipITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("66.249.157.90", Row("Jamaica", "Portmore")),
Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown"))
)
Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown")))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join,
import org.apache.spark.sql.types.DataTypes

class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
extends SparkFunSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {
Expand All @@ -28,20 +28,18 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
private val pplParser = new PPLSyntaxParser()

/**
 * Builds the expected logical plan for a geoip lookup by composing the join plan
 * with the projection plan.
 *
 * @param ipAddress
 *   attribute holding the IP address; forwarded to `getJoinPlan`
 * @param left
 *   left-hand plan of the join; forwarded to `getJoinPlan`
 * @param right
 *   right-hand plan of the join; forwarded to `getJoinPlan`
 * @param projectionProperties
 *   aliased expression appended to the projection by `getProjection`
 * @return
 *   the result of `getProjection` applied to the join plan
 */
private def getGeoIpQueryPlan(
    ipAddress: UnresolvedAttribute,
    left: LogicalPlan,
    right: LogicalPlan,
    projectionProperties: Alias): LogicalPlan = {
  val joinPlan = getJoinPlan(ipAddress, left, right)
  getProjection(joinPlan, projectionProperties)
}

private def getJoinPlan(
ipAddress: UnresolvedAttribute,
left : LogicalPlan,
right : LogicalPlan
) : LogicalPlan = {
ipAddress: UnresolvedAttribute,
left: LogicalPlan,
right: LogicalPlan): LogicalPlan = {
val is_ipv4 = ScalaUDF(
SerializableUdf.geoIpUtils.isIpv4,
DataTypes.BooleanType,
Expand All @@ -50,8 +48,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
Option.empty,
Option.apply("is_ipv4"),
false,
true
)
true)
val ip_to_int = ScalaUDF(
SerializableUdf.geoIpUtils.ipToInt,
DataTypes.createDecimalType(38, 0),
Expand All @@ -60,29 +57,34 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
Option.empty,
Option.apply("ip_to_int"),
false,
true
)
true)

val t1 = SubqueryAlias("t1", left)
val t2 = SubqueryAlias("t2", right)

val joinCondition = And(
And(
GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")),
LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))
),
EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))
)
LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))),
EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4")))
Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE)
}

/**
 * Wraps a join plan in the projection and column drop expected for a geoip lookup.
 *
 * The projection selects every column (`UnresolvedStar(None)`) plus
 * `projectionProperties`; the listed `t2.*` attributes are then dropped with
 * `DataFrameDropColumns`.
 *
 * @param joinPlan
 *   plan used as the child of the projection
 * @param projectionProperties
 *   aliased expression appended after the star in the projection list
 * @return
 *   a `DataFrameDropColumns` node over the projection
 */
private def getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = {
  val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan)
  // Attributes of the t2-aliased side that are dropped after the projection.
  val dropList = Seq(
    "t2.country_iso_code",
    "t2.country_name",
    "t2.continent_name",
    "t2.region_iso_code",
    "t2.region_name",
    "t2.city_name",
    "t2.time_zone",
    "t2.location",
    "t2.cidr",
    "t2.ip_range_start",
    "t2.ip_range_end",
    "t2.ipv4").map(UnresolvedAttribute(_))
  DataFrameDropColumns(dropList, projection)
}

Expand All @@ -98,16 +100,24 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
val sourceTable = UnresolvedRelation(seq("users"))
val geoTable = UnresolvedRelation(seq("geoip"))

val projectionStruct = CreateNamedStruct(Seq(
Literal("country_iso_code"), UnresolvedAttribute("t2.country_iso_code"),
Literal("country_name"), UnresolvedAttribute("t2.country_name"),
Literal("continent_name"), UnresolvedAttribute("t2.continent_name"),
Literal("region_iso_code"), UnresolvedAttribute("t2.region_iso_code"),
Literal("region_name"), UnresolvedAttribute("t2.region_name"),
Literal("city_name"), UnresolvedAttribute("t2.city_name"),
Literal("time_zone"), UnresolvedAttribute("t2.time_zone"),
Literal("location"), UnresolvedAttribute("t2.location")
))
val projectionStruct = CreateNamedStruct(
Seq(
Literal("country_iso_code"),
UnresolvedAttribute("t2.country_iso_code"),
Literal("country_name"),
UnresolvedAttribute("t2.country_name"),
Literal("continent_name"),
UnresolvedAttribute("t2.continent_name"),
Literal("region_iso_code"),
UnresolvedAttribute("t2.region_iso_code"),
Literal("region_name"),
UnresolvedAttribute("t2.region_name"),
Literal("city_name"),
UnresolvedAttribute("t2.city_name"),
Literal("time_zone"),
UnresolvedAttribute("t2.time_zone"),
Literal("location"),
UnresolvedAttribute("t2.location")))
val structProjection = Alias(projectionStruct, "a")()

val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection)
Expand Down Expand Up @@ -135,7 +145,6 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}


test("test geoip function - ipAddress col exist in geoip table") {
val context = new CatalystPlanContext

Expand All @@ -158,7 +167,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
test("test geoip function - duplicate parameters") {
val context = new CatalystPlanContext

val exception = intercept[IllegalStateException]{
val exception = intercept[IllegalStateException] {
planTransformer.visit(
plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"),
context)
Expand Down Expand Up @@ -197,10 +206,12 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
val ipAddress = UnresolvedAttribute("ip_address")
val sourceTable = UnresolvedRelation(seq("users"))
val geoTable = UnresolvedRelation(seq("geoip"))
val projectionStruct = CreateNamedStruct(Seq(
Literal("country_name"), UnresolvedAttribute("t2.country_name"),
Literal("location"), UnresolvedAttribute("t2.location")
))
val projectionStruct = CreateNamedStruct(
Seq(
Literal("country_name"),
UnresolvedAttribute("t2.country_name"),
Literal("location"),
UnresolvedAttribute("t2.location")))
val structProjection = Alias(projectionStruct, "a")()

val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection)
Expand All @@ -214,7 +225,9 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite

val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"),
plan(
pplParser,
"source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"),
context)

val ipAddress = UnresolvedAttribute("ip_address")
Expand All @@ -237,7 +250,9 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite

val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"),
plan(
pplParser,
"source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"),
context)

val ipAddress = UnresolvedAttribute("ip_address")
Expand Down

0 comments on commit 543740d

Please sign in to comment.