Skip to content

Commit

Permalink
Add in CORS functionality for supporting the short-lived URLs for cli…
Browse files Browse the repository at this point in the history
…ent-side cross origin access.
  • Loading branch information
josh-seidel-db committed Feb 27, 2023
1 parent 47d13e3 commit e0a64ef
Showing 1 changed file with 134 additions and 69 deletions.
203 changes: 134 additions & 69 deletions server/src/main/scala/io/delta/sharing/server/DeltaSharingService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.delta.sharing.server

import java.io.{ByteArrayOutputStream, File, FileNotFoundException}
import java.lang.reflect.Method
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.AccessDeniedException
import java.security.MessageDigest
Expand All @@ -31,7 +32,9 @@ import com.linecorp.armeria.common.auth.OAuth2Token
import com.linecorp.armeria.internal.server.ResponseConversionUtil
import com.linecorp.armeria.server.{Server, ServiceRequestContext}
import com.linecorp.armeria.server.annotation.{ConsumesJson, Default, ExceptionHandler, ExceptionHandlerFunction, Get, Head, Param, Post, ProducesJson}
import com.linecorp.armeria.server.annotation.decorator.CorsDecorator
import com.linecorp.armeria.server.auth.AuthService
import com.linecorp.armeria.server.cors.CorsService
import io.delta.standalone.internal.DeltaCDFErrors
import io.delta.standalone.internal.DeltaCDFIllegalArgumentException
import io.delta.standalone.internal.DeltaDataSource
Expand All @@ -42,7 +45,7 @@ import org.slf4j.LoggerFactory
import scalapb.json4s.Printer

import io.delta.sharing.server.config.ServerConfig
import io.delta.sharing.server.model.SingleAction
import io.delta.sharing.server.model.{ AddCDCFile, AddFile, AddFileForCDF, RemoveFile, SingleAction }
import io.delta.sharing.server.protocol._
import io.delta.sharing.server.util.JsonUtils

Expand All @@ -62,9 +65,9 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction {
private val logger = LoggerFactory.getLogger(classOf[DeltaSharingServiceExceptionHandler])

override def handleException(
ctx: ServiceRequestContext,
req: HttpRequest,
cause: Throwable): HttpResponse = {
ctx: ServiceRequestContext,
req: HttpRequest,
cause: Throwable): HttpResponse = {
cause match {
// Handle exceptions caused by incorrect requests
case _: DeltaSharingNoSuchElementException =>
Expand Down Expand Up @@ -128,7 +131,7 @@ class DeltaSharingServiceExceptionHandler extends ExceptionHandlerFunction {
//
// valid json, but a field may have an incorrect type
case (_: scalapb.json4s.JsonFormatException |
// invalid json
// invalid json
_: com.fasterxml.jackson.databind.JsonMappingException) =>
HttpResponse.of(
HttpStatus.BAD_REQUEST,
Expand Down Expand Up @@ -190,24 +193,31 @@ class DeltaSharingService(serverConfig: ServerConfig) {
/**
 * Returns one page of shares.
 *
 * The whole body is evaluated inside `processRequest`.
 *
 * @param serviceRequestContext context of the current request; not read by this method
 *                              (presumably injected by the annotated-service framework)
 * @param httpRequest the incoming HTTP request; not read by this method
 * @param maxResults value of the `maxResults` query parameter, `500` when it is absent
 * @param pageToken value of the `pageToken` query parameter; may be null
 * @return a [[ListSharesResponse]] built from the shares and the next page token returned by
 *         `sharedTableManager.listShares`
 */
@Get("/shares")
@ProducesJson
def listShares(
    serviceRequestContext: ServiceRequestContext,
    httpRequest: HttpRequest,
    @Param("maxResults") @Default("500") maxResults: Int,
    @Param("pageToken") @Nullable pageToken: String): ListSharesResponse = processRequest {
  // Option(...) maps a null pageToken to None.
  val (shares, nextPageToken) =
    sharedTableManager.listShares(Option(pageToken), Some(maxResults))
  ListSharesResponse(shares, nextPageToken)
}

/**
 * Looks up a single share by name.
 *
 * The whole body is evaluated inside `processRequest`.
 *
 * @param serviceRequestContext context of the current request; not read by this method
 *                              (presumably injected by the annotated-service framework)
 * @param httpRequest the incoming HTTP request; not read by this method
 * @param share value of the `{share}` path variable
 * @return a [[GetShareResponse]] wrapping the result of `sharedTableManager.getShare`
 */
@Get("/shares/{share}")
@ProducesJson
def getShare(
    serviceRequestContext: ServiceRequestContext,
    httpRequest: HttpRequest,
    @Param("share") share: String): GetShareResponse = processRequest {
  GetShareResponse(share = Some(sharedTableManager.getShare(share)))
}

@Get("/shares/{share}/schemas")
@ProducesJson
def listSchemas(
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListSchemasResponse = processRequest {
val (schemas, nextPageToken) =
sharedTableManager.listSchemas(share, Option(pageToken), Some(maxResults))
ListSchemasResponse(schemas, nextPageToken)
Expand All @@ -216,10 +226,12 @@ class DeltaSharingService(serverConfig: ServerConfig) {
@Get("/shares/{share}/schemas/{schema}/tables")
@ProducesJson
def listTables(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("maxResults") @Default("500") maxResults: Int,
@Param("pageToken") @Nullable pageToken: String): ListTablesResponse = processRequest {
val (tables, nextPageToken) =
sharedTableManager.listTables(share, schema, Option(pageToken), Some(maxResults))
ListTablesResponse(tables, nextPageToken)
Expand All @@ -228,27 +240,27 @@ class DeltaSharingService(serverConfig: ServerConfig) {
/**
 * Returns one page of the tables of a share, across all of its schemas.
 *
 * The whole body is evaluated inside `processRequest`.
 *
 * @param serviceRequestContext context of the current request; not read by this method
 *                              (presumably injected by the annotated-service framework)
 * @param httpRequest the incoming HTTP request; not read by this method
 * @param share value of the `{share}` path variable
 * @param maxResults value of the `maxResults` query parameter, `500` when it is absent
 * @param pageToken value of the `pageToken` query parameter; may be null
 * @return a [[ListAllTablesResponse]] built from the tables and the next page token returned
 *         by `sharedTableManager.listAllTables`
 */
@Get("/shares/{share}/all-tables")
@ProducesJson
def listAllTables(
    serviceRequestContext: ServiceRequestContext,
    httpRequest: HttpRequest,
    @Param("share") share: String,
    @Param("maxResults") @Default("500") maxResults: Int,
    @Param("pageToken") @Nullable pageToken: String): ListAllTablesResponse = processRequest {
  // Option(...) maps a null pageToken to None.
  val (tables, nextPageToken) =
    sharedTableManager.listAllTables(share, Option(pageToken), Some(maxResults))
  ListAllTablesResponse(tables, nextPageToken)
}

/**
 * Starts a 200 response-headers builder that already carries the table version.
 *
 * @param version table version, written as a string under `DELTA_TABLE_VERSION_HEADER`
 * @return the builder, so callers can add further headers before building
 */
private def createHeadersBuilderForTableVersion(version: Long): ResponseHeadersBuilder = {
  val okHeaders = ResponseHeaders.builder(HttpStatus.OK.code)
  okHeaders.set(DELTA_TABLE_VERSION_HEADER, version.toString)
}

// TODO: deprecate HEAD request in favor of the GET request
@Head("/shares/{share}/schemas/{schema}/tables/{table}")
@Get("/shares/{share}/schemas/{schema}/tables/{table}/version")
def getTableVersion(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
@Param("startingTimestamp") @Nullable startingTimestamp: String
): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
@Param("startingTimestamp") @Nullable startingTimestamp: String
): HttpResponse = processRequest {
val tableConfig = sharedTableManager.getTable(share, schema, table)
if (startingTimestamp != null && !tableConfig.cdfEnabled) {
throw new DeltaSharingIllegalArgumentException("Reading table by version or timestamp is" +
Expand All @@ -261,18 +273,31 @@ class DeltaSharingService(serverConfig: ServerConfig) {
if (startingTimestamp != null && version < tableConfig.startVersion) {
throw new DeltaSharingIllegalArgumentException(
s"You can only query table data since version ${tableConfig.startVersion}." +
s"The provided timestamp($startingTimestamp) corresponds to $version."
s"The provided timestamp($startingTimestamp) corresponds to $version."
)
}
val headers = createHeadersBuilderForTableVersion(version).build()
val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code)

val corsService = CorsService.builder(serverConfig.getHost)
val setCorsResponseHeadersMethod = corsService.getClass
.getDeclaredMethod("setCorsResponseHeaders",
classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder])
setCorsResponseHeadersMethod.setAccessible(true)
setCorsResponseHeadersMethod
.invoke(corsService, serviceRequestContext, httpRequest, headersBuilder)

val headers = headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.toString)
.build()
HttpResponse.of(headers)
}

@Get("/shares/{share}/schemas/{schema}/tables/{table}/metadata")
def getMetadata(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String): HttpResponse = processRequest {
import scala.collection.JavaConverters._
val tableConfig = sharedTableManager.getTable(share, schema, table)
val (v, actions) = deltaSharedTableLoader.loadTable(tableConfig).query(
Expand All @@ -283,16 +308,18 @@ class DeltaSharingService(serverConfig: ServerConfig) {
version = None,
timestamp = None,
startingVersion = None)
streamingOutput(Some(v), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(v), actions)
}

@Post("/shares/{share}/schemas/{schema}/tables/{table}/query")
@ConsumesJson
def listFiles(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
request: QueryTableRequest): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
request: QueryTableRequest): HttpResponse = processRequest {
val numVersionParams = Seq(request.version, request.timestamp, request.startingVersion)
.filter(_.isDefined).size
if (numVersionParams > 1) {
Expand Down Expand Up @@ -337,21 +364,23 @@ class DeltaSharingService(serverConfig: ServerConfig) {
}
logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table " +
s"and sign ${actions.length - 2} urls for table $share/$schema/$table")
streamingOutput(Some(version), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(version), actions)
}

@Get("/shares/{share}/schemas/{schema}/tables/{table}/changes")
@ConsumesJson
def listCdfFiles(
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
@Param("startingVersion") @Nullable startingVersion: String,
@Param("endingVersion") @Nullable endingVersion: String,
@Param("startingTimestamp") @Nullable startingTimestamp: String,
@Param("endingTimestamp") @Nullable endingTimestamp: String,
@Param("includeHistoricalMetadata") @Nullable includeHistoricalMetadata: String
): HttpResponse = processRequest {
serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
@Param("share") share: String,
@Param("schema") schema: String,
@Param("table") table: String,
@Param("startingVersion") @Nullable startingVersion: String,
@Param("endingVersion") @Nullable endingVersion: String,
@Param("startingTimestamp") @Nullable startingTimestamp: String,
@Param("endingTimestamp") @Nullable endingTimestamp: String,
@Param("includeHistoricalMetadata") @Nullable includeHistoricalMetadata: String
): HttpResponse = processRequest {
val start = System.currentTimeMillis
val tableConfig = sharedTableManager.getTable(share, schema, table)
if (!tableConfig.cdfEnabled) {
Expand All @@ -370,19 +399,55 @@ class DeltaSharingService(serverConfig: ServerConfig) {
)
logger.info(s"Took ${System.currentTimeMillis - start} ms to load the table cdf " +
s"and sign ${actions.length - 2} urls for table $share/$schema/$table")
streamingOutput(Some(v), actions)
streamingOutput(serviceRequestContext, httpRequest, Some(v), actions)
}

private def streamingOutput(version: Option[Long], actions: Seq[SingleAction]): HttpResponse = {
val headers = if (version.isDefined) {
createHeadersBuilderForTableVersion(version.get)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE)
.build()
} else {
ResponseHeaders.builder(200)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE)
.build()
private def streamingOutput(serviceRequestContext: ServiceRequestContext,
httpRequest: HttpRequest,
version: Option[Long],
actions: Seq[SingleAction]): HttpResponse = {
val headersBuilder = ResponseHeaders.builder(HttpStatus.OK.code)
.set(HttpHeaderNames.CONTENT_TYPE, DELTA_TABLE_METADATA_CONTENT_TYPE);
if (actions.nonEmpty) {
val urls: Seq[String] = actions.map(((e: SingleAction) => {
val a = e.unwrap
a.getClass match {
case v if v == classOf[AddFile] => a.asInstanceOf[AddFile].url
case v if v == classOf[AddFileForCDF] => a.asInstanceOf[AddFileForCDF].url
case v if v == classOf[AddCDCFile] => a.asInstanceOf[AddCDCFile].url
case v if v == classOf[RemoveFile] => a.asInstanceOf[RemoveFile].url
case _ => null
}
}): (SingleAction => String)).filter(_ != null)
val corsUrls = (serverConfig.getHost +: urls)
val corsService = CorsService.builder(corsUrls: _*)

/* From CorsService, the private method setCorsResponseHeaders is used to set the headers
via the builder and would be called if the library were properly utilized with static
service end-points. Since the short-lived URLs are dynamic and change, the CORS headers
will change and need to be dynamic each time. Note this does not change the headers on
the Blob Storage that still need to allow the origin. The code exposes the method and
calls the method to allow the Delta Sharing Server to properly send back the expected
headers. For more information please see:
https://advancedweb.hu/how-to-solve-cors-problems-when-redirecting-to-s3-signed-urls/
and
https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
NOTE: The setCorsResponseHeadersMethod can be made into a field on the class to then
have only the method invoked during each call of streamingOutput.
*/
val setCorsResponseHeadersMethod = corsService.getClass
.getDeclaredMethod("setCorsResponseHeaders",
classOf[ServiceRequestContext], classOf[HttpRequest], classOf[ResponseHeadersBuilder])
setCorsResponseHeadersMethod.setAccessible(true)

setCorsResponseHeadersMethod
.invoke(corsService, serviceRequestContext, httpRequest, headersBuilder)
}

if (version.isDefined) headersBuilder.set(DELTA_TABLE_VERSION_HEADER, version.get.toString)
val headers = headersBuilder.build()

ResponseConversionUtil.streamingFrom(
actions.asJava.stream(),
headers,
Expand All @@ -393,7 +458,7 @@ class DeltaSharingService(serverConfig: ServerConfig) {
out.write('\n')
HttpData.wrap(out.toByteArray)
},
ServiceRequestContext.current().blockingTaskExecutor())
serviceRequestContext.blockingTaskExecutor())
}
}

Expand Down Expand Up @@ -476,10 +541,10 @@ object DeltaSharingService {
}

private def checkCDFOptionsValidity(
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Unit = {
startingVersion: Option[String],
endingVersion: Option[String],
startingTimestamp: Option[String],
endingTimestamp: Option[String]): Unit = {
// check if we have both version and timestamp parameters
if (startingVersion.isDefined && startingTimestamp.isDefined) {
throw DeltaCDFErrors.multipleCDFBoundary("starting")
Expand Down Expand Up @@ -510,16 +575,16 @@ object DeltaSharingService {
}

/**
 * Builds the CDF option map from the optional version/timestamp boundaries.
 *
 * The combination of boundaries is first checked by `checkCDFOptionsValidity`, which for
 * example rejects a starting version supplied together with a starting timestamp. Each
 * boundary that is defined then contributes one entry, keyed by the matching
 * `DeltaDataSource` constant; undefined boundaries contribute nothing.
 *
 * @param startingVersion   optional starting version, keyed by `CDF_START_VERSION_KEY`
 * @param endingVersion     optional ending version, keyed by `CDF_END_VERSION_KEY`
 * @param startingTimestamp optional starting timestamp, keyed by `CDF_START_TIMESTAMP_KEY`
 * @param endingTimestamp   optional ending timestamp, keyed by `CDF_END_TIMESTAMP_KEY`
 * @return a map holding one entry per defined boundary
 */
private[server] def getCdfOptionsMap(
    startingVersion: Option[String],
    endingVersion: Option[String],
    startingTimestamp: Option[String],
    endingTimestamp: Option[String]): Map[String, String] = {
  checkCDFOptionsValidity(startingVersion, endingVersion, startingTimestamp, endingTimestamp)

  (startingVersion.map(DeltaDataSource.CDF_START_VERSION_KEY -> _) ++
    endingVersion.map(DeltaDataSource.CDF_END_VERSION_KEY -> _) ++
    startingTimestamp.map(DeltaDataSource.CDF_START_TIMESTAMP_KEY -> _) ++
    endingTimestamp.map(DeltaDataSource.CDF_END_TIMESTAMP_KEY -> _)).toMap
}

def main(args: Array[String]): Unit = {
Expand Down

0 comments on commit e0a64ef

Please sign in to comment.