Skip to content

Commit

Permalink
fix(scio-cosmosdb): Fix the CosmosDbBoundedReader#getCurrent, is now …
Browse files Browse the repository at this point in the history
…idempotent and add @experimental annotations
  • Loading branch information
Miuler committed Feb 2, 2023
1 parent 8be709d commit 0fe0b7d
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ import org.bson.Document
trait CosmosDbIO[T] extends ScioIO[T] {}

case class ReadCosmosDdIO(
endpoint: String = null,
key: String = null,
database: String = null,
container: String = null,
query: String = null
endpoint: String,
key: String,
database: String,
container: String,
query: String
) extends CosmosDbIO[Document] {
override type ReadP = Unit
override type WriteP = Nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@
package com.spotify.scio.cosmosdb.read

import com.azure.cosmos.models.CosmosQueryRequestOptions
import com.azure.cosmos.{CosmosClient, CosmosClientBuilder}
import org.apache.beam.sdk.annotations.Experimental
import org.apache.beam.sdk.annotations.Experimental.Kind
import com.azure.cosmos.{ CosmosClient, CosmosClientBuilder }
import com.spotify.scio.annotations.experimental
import org.apache.beam.sdk.io.BoundedSource
import org.bson.Document
import org.slf4j.LoggerFactory

@Experimental(Kind.SOURCE_SINK)
@experimental
private[read] class CosmosDbBoundedReader(cosmosSource: CosmosDbBoundedSource)
extends BoundedSource.BoundedReader[Document] {
private val log = LoggerFactory.getLogger(getClass)
private var maybeClient: Option[CosmosClient] = None
private var maybeIterator: Option[java.util.Iterator[Document]] = None
@volatile private var current: Option[Document] = None
@volatile private var recordsReturned = 0L

override def start(): Boolean = {
maybeClient = Some(
Expand All @@ -55,19 +56,26 @@ private[read] class CosmosDbBoundedReader(cosmosSource: CosmosDbBoundedSource)
.iterator()
}

true
advance()
}

override def advance(): Boolean = maybeIterator.exists(_.hasNext)
override def advance(): Boolean = maybeIterator match {
case Some(iterator) if iterator.hasNext =>
current = Some(iterator.next())
recordsReturned += 1
true
case _ =>
false
}

override def getCurrent: Document =
maybeIterator
.filter(_.hasNext)
// .map(iterator => new Document(iterator.next()))
.map(_.next())
.orNull
override def getCurrent: Document = current.orNull

override def getCurrentSource: CosmosDbBoundedSource = cosmosSource

override def close(): Unit = maybeClient.foreach(_.close())
override def close(): Unit = {
log.info("Closing reader after reading {} records.", recordsReturned)
maybeClient.foreach(_.close())
maybeClient = None
maybeIterator = None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@

package com.spotify.scio.cosmosdb.read

import org.apache.beam.sdk.annotations.Experimental
import org.apache.beam.sdk.annotations.Experimental.Kind
import org.apache.beam.sdk.coders.{Coder, SerializableCoder}
import com.spotify.scio.annotations.experimental
import org.apache.beam.sdk.coders.{ Coder, SerializableCoder }
import org.apache.beam.sdk.io.BoundedSource
import org.apache.beam.sdk.options.PipelineOptions
import org.bson.Document
Expand All @@ -29,7 +28,7 @@ import java.util.Collections
/**
* A CosmosDB Core (SQL) API [[BoundedSource]] reading [[Document]] from a given instance.
*/
@Experimental(Kind.SOURCE_SINK)
@experimental
private[read] class CosmosDbBoundedSource(private[read] val readCosmos: CosmosDbRead)
extends BoundedSource[Document] {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,29 @@

package com.spotify.scio.cosmosdb.read

import org.apache.beam.sdk.annotations.Experimental
import org.apache.beam.sdk.annotations.Experimental.Kind
import com.spotify.scio.annotations.experimental
import org.apache.beam.sdk.io.Read
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.values.{PBegin, PCollection}
import org.apache.beam.sdk.values.{ PBegin, PCollection }
import org.bson.Document
import org.slf4j.LoggerFactory

/** A [[PTransform]] to read data from CosmosDB Core (SQL) API. */
@Experimental(Kind.SOURCE_SINK)
@experimental
private[cosmosdb] case class CosmosDbRead(
endpoint: String = null,
key: String = null,
database: String = null,
container: String = null,
endpoint: String,
key: String,
database: String,
container: String,
query: String = null
) extends PTransform[PBegin, PCollection[Document]] {

private val log = LoggerFactory.getLogger(classOf[CosmosDbRead])

/** Create new ReadCosmos based into previous ReadCosmos, modifying the endpoint */
def withCosmosEndpoint(endpoint: String): CosmosDbRead = this.copy(endpoint = endpoint)

def withCosmosKey(key: String): CosmosDbRead = this.copy(key = key)

def withDatabase(database: String): CosmosDbRead = this.copy(database = database)

def withQuery(query: String): CosmosDbRead = this.copy(query = query)

def withContainer(container: String): CosmosDbRead = this.copy(container = container)

override def expand(input: PBegin): PCollection[Document] = {
log.debug(s"Read CosmosDB with endpoint: $endpoint and query: $query")
validate()

// input.getPipeline.apply(Read.from(new CosmosSource(this)))
input.apply(Read.from(new CosmosDbBoundedSource(this)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.spotify.scio.cosmosdb.syntax

import com.spotify.scio.ScioContext
import com.spotify.scio.annotations.experimental
import com.spotify.scio.cosmosdb.ReadCosmosDdIO
import com.spotify.scio.values.SCollection
import org.bson.Document
Expand All @@ -28,6 +29,7 @@ trait ScioContextSyntax {

final class CosmosDbScioContextOps(private val sc: ScioContext) extends AnyVal {

@experimental
/**
* Read data from CosmosDB CORE (SQL) API
*
Expand Down

0 comments on commit 0fe0b7d

Please sign in to comment.