Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in CORS functionality for supporting the short-lived URLs for client-side cross origin access #275

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 121 additions & 56 deletions server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.delta.sharing.server

import java.io.{ByteArrayOutputStream, File, FileNotFoundException}
import java.lang.reflect.Method
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.AccessDeniedException
import java.security.MessageDigest
Expand All @@ -31,7 +32,9 @@ import com.linecorp.armeria.common.auth.OAuth2Token
import com.linecorp.armeria.internal.server.ResponseConversionUtil
import com.linecorp.armeria.server.{Server, ServiceRequestContext}
import com.linecorp.armeria.server.annotation.{ConsumesJson, Default, ExceptionHandler, ExceptionHandlerFunction, Get, Head, Param, Post, ProducesJson}
import com.linecorp.armeria.server.annotation.decorator.CorsDecorator
import com.linecorp.armeria.server.auth.AuthService
import com.linecorp.armeria.server.cors.CorsService
import io.delta.standalone.internal.DeltaCDFErrors
import io.delta.standalone.internal.DeltaCDFIllegalArgumentException
import io.delta.standalone.internal.DeltaDataSource
Expand All @@ -42,7 +45,7 @@ import org.slf4j.LoggerFactory
import scalapb.json4s.Printer

import io.delta.sharing.server.config.ServerConfig
import io.delta.sharing.server.model.SingleAction
import io.delta.sharing.server.model.{ AddCDCFile, AddFile, AddFileForCDF, RemoveFile, SingleAction }
import io.delta.sharing.server.protocol._
import io.delta.sharing.server.util.JsonUtils

Expand All @@ -62,9 +65,9 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction {
private val logger = LoggerFactory.getLogger(classOf[DeltaSharingServiceExceptionHandler])

override def handleException(
ctx: ServiceRequestContext,
req: HttpRequest,
cause: Throwable): HttpResponse = {
ctx: ServiceRequestContext,
req: HttpRequest,
cause: Throwable): HttpResponse = {
cause match {
// Handle exceptions caused by incorrect requests
case _: DeltaSharingNoSuchElementException =>
Expand Down Expand Up @@ -128,7 +131,7 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction {
//
// valid json but may have an incorrect field type
case (_: scalapb.json4s.JsonFormatException |
// invalid json
// invalid json
_: com.fasterxml.jackson.databind.JsonMappingException) =>
HttpResponse.of(
HttpStatus.BAD_REQUEST,
Expand Down Expand Up @@ -190,24 +193,31 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Get("/shares")
@ProducesJson
def listShares(
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSharesResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSharesResponse = processRequest {
val (shares, nextPageToken) = sharedTableManager.listShares(Option(pageToken), Some(maxResults))
ListSharesResponse(shares, nextPageToken)
}

@Get("/shares/{share}")
@ProducesJson
def getShare(@Param("share") share: String): GetShareResponse = processRequest {
def getShare(
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String): GetShareResponse = processRequest {
GetShareResponse(share = Some(sharedTableManager.getShare(share)))
}

@Get("/shares/{share}/schemas")
@ProducesJson
def listSchemas(
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest {
val (schemas, nextPageToken) =
sharedTableManager.listSchemas(share, Option(pageToken), Some(maxResults))
ListSchemasResponse(schemas, nextPageToken)
Expand All @@ -216,10 +226,12 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Get("/shares/{share}/schemas/{schema}/tables")
@ProducesJson
def listTables(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest {
val (tables, nextPageToken) =
sharedTableManager.listTables(share, schema, Option(pageToken), Some(maxResults))
ListTablesResponse(tables, nextPageToken)
Expand All @@ -228,22 +240,22 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Get("/shares/{share}/all-tables")
@ProducesJson
def listAllTables(
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListAllTablesResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListAllTablesResponse = processRequest {
val (tables, nextPageToken) =
sharedTableManager.listAllTables(share, Option(pageToken), Some(maxResults))
ListAllTablesResponse(tables, nextPageToken)
}

private def createHeadersBuilderForTableVersion(version: Long): ResponseHeadersBuilder = {
ResponseHeaders.builder(200).set(DELTA_TABLE_VERSION_HEADER, version.toString)
}

// TODO: deprecate HEAD request in favor of the GET request
@Head("/shares/{share}/schemas/{schema}/tables/{table}")
@Get("/shares/{share}/schemas/{schema}/tables/{table}/version")
def getTableVersion(
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
Expand All @@ -261,18 +273,31 @@ class DeltaSharingService(serverConfig: ServerConfig) {
if (startingTimestamp != null && version < tableConfig.startVersion) {
throw new DeltaSharingIllegalArgumentException(
s"You can only query table data since version ${tableConfig.startVersion}." +
s"The provided timestamp($startingTimestamp) corresponds to $version."
s"The provided timestamp($startingTimestamp) corresponds to $version."
)
}
val headers = createHeadersBuilderForTableVersion(version).build()
val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code)

val corsService = CorsService.builder(serverConfig.getHost)
val setCorsResponseHeadersMethod = corsService.getClass
.getDeclaredMethod("setCorsResponseHeaders",
classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder])
setCorsResponseHeadersMethod.setAccessible(true)
setCorsResponseHeadersMethod
.invoke(corsService, serviceRequestContext, httpRequest, headersBuilder)

val headers = headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.toString)
.build()
HttpResponse.of(headers)
}

@Get("/shares/{share}/schemas/{schema}/tables/{table}/metadata")
def getMetadata(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String): HttpResponse = processRequest {
import scala.collection.JavaConverters._
val tableConfig = sharedTableManager.getTable(share, schema, table)
val (v, actions) = deltaSharedTableLoader.loadTable(tableConfig).query(
Expand All @@ -283,16 +308,18 @@ class DeltaSharingService(serverConfig: ServerConfig) {
version = None,
timestamp = None,
startingVersion = None)
streamingOutput(Some(v), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(v), actions)
}

@Post("/shares/{share}/schemas/{schema}/tables/{table}/query")
@ConsumesJson
def listFiles(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
request: QueryTableRequest): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
request: QueryTableRequest): HttpResponse = processRequest {
val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion)
.filter(_.isDefined).size
if (numVersionParams > 1) {
Expand Down Expand Up @@ -337,12 +364,14 @@ class DeltaSharingService(serverConfig: ServerConfig) {
}
logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table " +
s"and sign ${actions.length - 2} urls for table $share/$schema/$table")
streamingOutput(Some(version), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(version), actions)
}

@Get("/shares/{share}/schemas/{schema}/tables/{table}/changes")
@ConsumesJson
def listCdfFiles(
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
Expand All @@ -351,7 +380,7 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Param("startingTimestamp") @Nullable startingTimestamp: String,
@Param("endingTimestamp") @Nullable endingTimestamp: String,
@Param("includeHistoricalMetadata") @Nullable includeHistoricalMetadata: String
): HttpResponse = processRequest {
): HttpResponse = processRequest {
val start = System.currentTimeMillis
val tableConfig = sharedTableManager.getTable(share, schema, table)
if (!tableConfig.cdfEnabled) {
Expand All @@ -370,19 +399,55 @@ class DeltaSharingService(serverConfig: ServerConfig) {
)
logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table cdf " +
s"and sign ${actions.length - 2} urls for table $share/$schema/$table")
streamingOutput(Some(v), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(v), actions)
}

private def streamingOutput(version: Option[Long], actions: Seq[SingleAction]): HttpResponse = {
val headers = if (version.isDefined) {
createHeadersBuilderForTableVersion(version.get)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE)
.build()
} else {
ResponseHeaders.builder(200)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE)
.build()
private def streamingOutput(serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
version: Option[Long],
actions: Seq[SingleAction]): HttpResponse = {
val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE);
if (actions.nonEmpty) {
val urls: Seq[String] = actions.map(((e: SingleAction) => {
val a = e.unwrap
a.getClass match {
case v if v == classOf[AddFile] => a.asInstanceOf[AddFile].url
case v if v == classOf[AddFileForCDF] => a.asInstanceOf[AddFileForCDF].url
case v if v == classOf[AddCDCFile] => a.asInstanceOf[AddCDCFile].url
case v if v == classOf[RemoveFile] => a.asInstanceOf[RemoveFile].url
case _ => null
}
}): (SingleAction => String)).filter(_ != null)
val corsUrls = (serverConfig.getHost +: urls)
val corsService = CorsService.builder(corsUrls: _*)

/* CorsService's private method setCorsResponseHeaders sets the CORS headers on a
response-headers builder; it would be called by the library itself if CorsService were
used in the usual way, with static service end-points. Because the short-lived URLs are
dynamic and change, the CORS headers change too and need to be set dynamically each time.
Note this does not change the headers on the Blob Storage, which still needs to allow the
origin. The code makes the method accessible and invokes it so that the Delta Sharing
Server properly sends back the expected headers. For more information please see:
https://advancedweb.hu/how-to-solve-cors-problems-when-redirecting-to-s3-signed-urls/
and
https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS

NOTE: The setCorsResponseHeadersMethod can be made into a field on the class to then
have only the method invoked during each call of streamingOutput.
*/
val setCorsResponseHeadersMethod = corsService.getClass
.getDeclaredMethod("setCorsResponseHeaders",
classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder])
setCorsResponseHeadersMethod.setAccessible(true)

setCorsResponseHeadersMethod
.invoke(corsService, serviceRequestContext, httpRequest, headersBuilder)
}

if (version.isDefined) headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.get.toString)
val headers = headersBuilder.build()

ResponseConversionUtil.streamingFrom(
actions.asJava.stream(),
headers,
Expand All @@ -393,7 +458,7 @@ class DeltaSharingService(serverConfig: ServerConfig) {
out.write('\n')
HttpData.wrap(out.toByteArray)
},
ServiceRequestContext.current().blockingTaskExecutor())
serviceRequestContext.blockingTaskExecutor())
}
}

Expand Down Expand Up @@ -476,10 +541,10 @@ object DeltaSharingService {
}

private def checkCDFOptionsValidity(
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Unit = {
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Unit = {
// check if we have both version and timestamp parameters
if (startingVersion.isDefined && startingTimestamp.isDefined) {
throw DeltaCDFErrors.multipleCDFBoundary("starting")
Expand Down Expand Up @@ -510,16 +575,16 @@ object DeltaSharingService {
}

private[server] def getCdfOptionsMap(
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Map[String, String] = {
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Map[String, String] = {
checkCDFOptionsValidity(startingVersion, endingVersion, startingTimestamp, endingTimestamp)

(startingVersion.map(DeltaDataSource.CDF_START_VERSION_KEY -> _) ++
endingVersion.map(DeltaDataSource.CDF_END_VERSION_KEY -> _) ++
startingTimestamp.map(DeltaDataSource.CDF_START_TIMESTAMP_KEY -> _) ++
endingTimestamp.map(DeltaDataSource.CDF_END_TIMESTAMP_KEY -> _)).toMap
endingVersion.map(DeltaDataSource.CDF_END_VERSION_KEY -> _) ++
startingTimestamp.map(DeltaDataSource.CDF_START_TIMESTAMP_KEY -> _) ++
endingTimestamp.map(DeltaDataSource.CDF_END_TIMESTAMP_KEY -> _)).toMap
}

def main(args: Array[String]): Unit = {
Expand Down