Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(dsl): Support source filtering in search requests #196

Merged
merged 11 commits into from
May 4, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import zio.elasticsearch.ElasticAggregation.{maxAggregation, multipleAggregation
import zio.elasticsearch.ElasticHighlight.highlight
import zio.elasticsearch.ElasticQuery._
import zio.elasticsearch.ElasticSort.sortBy
import zio.elasticsearch.domain.{TestDocument, TestSubDocument}
import zio.elasticsearch.domain.{PartialTestDocument, TestDocument, TestSubDocument}
import zio.elasticsearch.executor.Executor
import zio.elasticsearch.executor.response.MaxAggregationResponse
import zio.elasticsearch.query.sort.SortMode.Max
Expand All @@ -31,6 +31,7 @@ import zio.elasticsearch.request.{CreationOutcome, DeletionOutcome}
import zio.elasticsearch.result.{Item, UpdateByQueryResult}
import zio.elasticsearch.script.Script
import zio.json.ast.Json.{Arr, Str}
import zio.schema.codec.JsonCodec
import zio.stream.{Sink, ZSink}
import zio.test._
import zio.test.TestAspect._
Expand Down Expand Up @@ -506,6 +507,76 @@ object HttpExecutorSpec extends IntegrationSpec {
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("search for documents with source filtering") {
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
for {
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
_ <- Executor.execute(
ElasticRequest.upsert[TestDocument](firstSearchIndex, firstDocumentId, firstDocument)
)
_ <- Executor.execute(
ElasticRequest.upsert[TestDocument](firstSearchIndex, secondDocumentId, secondDocument)
)
_ <- Executor.execute(
ElasticRequest
.upsert[TestDocument](firstSearchIndex, thirdDocumentId, thirdDocument)
.refreshTrue
)
query = range(TestDocument.doubleField).gte(100.0)
res <- Executor
.execute(ElasticRequest.search(firstSearchIndex, query).includes[PartialTestDocument])
items <- res.items
} yield assert(items.map(item => Right(item.raw)))(
hasSameElements(
List(firstDocument, secondDocument, thirdDocument).map(document =>
TestDocument.schema.migrate(PartialTestDocument.schema).flatMap(_(document)).flatMap {
partialDocument =>
JsonCodec.jsonEncoder(PartialTestDocument.schema).toJsonAST(partialDocument)
}
)
)
)
}
} @@ around(
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("fail if an excluded source field is attempted to be decoded") {
checkOnce(genDocumentId, genTestDocument, genDocumentId, genTestDocument, genDocumentId, genTestDocument) {
(firstDocumentId, firstDocument, secondDocumentId, secondDocument, thirdDocumentId, thirdDocument) =>
val result =
for {
_ <- Executor.execute(ElasticRequest.deleteByQuery(firstSearchIndex, matchAll))
_ <- Executor.execute(
ElasticRequest.upsert[TestDocument](firstSearchIndex, firstDocumentId, firstDocument)
)
_ <- Executor.execute(
ElasticRequest.upsert[TestDocument](firstSearchIndex, secondDocumentId, secondDocument)
)
_ <- Executor.execute(
ElasticRequest
.upsert[TestDocument](firstSearchIndex, thirdDocumentId, thirdDocument)
.refreshTrue
)
query = range(TestDocument.doubleField).gte(100.0)
_ <- Executor
.execute(ElasticRequest.search(firstSearchIndex, query).excludes("intField"))
.documentAs[TestDocument]
} yield ()

assertZIO(result.exit)(
fails(
isSubtype[Exception](
assertException("Could not parse all documents successfully: .intField(missing)")
)
)
)
}
} @@ around(
Executor.execute(ElasticRequest.createIndex(firstSearchIndex)),
Executor.execute(ElasticRequest.deleteIndex(firstSearchIndex)).orDie
),
test("fail if any of results cannot be decoded") {
checkOnce(genDocumentId, genDocumentId, genTestDocument, genTestSubDocument) {
(documentId, subDocumentId, document, subDocument) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ object ElasticRequest {
index = index,
query = query,
sortBy = Chunk.empty,
excluded = None,
from = None,
highlights = None,
included = None,
routing = None,
searchAfter = None,
size = None
Expand Down Expand Up @@ -257,8 +259,10 @@ object ElasticRequest {
query = query,
aggregation = aggregation,
sortBy = Chunk.empty,
excluded = None,
from = None,
highlights = None,
included = None,
routing = None,
searchAfter = None,
size = None
Expand Down Expand Up @@ -553,8 +557,9 @@ object ElasticRequest {
extends ElasticRequest[SearchResult]
with HasFrom[SearchRequest]
with HasRouting[SearchRequest]
with HasSize[SearchRequest]
with HasSort[SearchRequest]
with HasSize[SearchRequest] {
with HasSourceFiltering[SearchRequest] {
def aggregate(aggregation: ElasticAggregation): SearchAndAggregateRequest

def highlights(value: Highlights): SearchRequest
Expand All @@ -566,32 +571,46 @@ object ElasticRequest {
index: IndexName,
query: ElasticQuery[_],
sortBy: Chunk[Sort],
excluded: Option[Chunk[String]],
from: Option[Int],
highlights: Option[Highlights],
included: Option[Chunk[String]],
routing: Option[Routing],
searchAfter: Option[Json],
size: Option[Int]
) extends SearchRequest { self =>

def aggregate(aggregation: ElasticAggregation): SearchAndAggregateRequest =
SearchAndAggregate(
index = index,
query = query,
aggregation = aggregation,
sortBy = sortBy,
excluded = excluded,
from = from,
highlights = highlights,
included = included,
routing = routing,
searchAfter = None,
size = size
)

def excludes(field: String, fields: String*): SearchRequest =
self.copy(excluded = excluded.map(_ ++ (field +: fields)).orElse(Some(field +: Chunk.fromIterable(fields))))

def from(value: Int): SearchRequest =
self.copy(from = Some(value))

def highlights(value: Highlights): SearchRequest =
self.copy(highlights = Some(value))

def includes(field: String, fields: String*): SearchRequest =
self.copy(included = included.map(_ ++ (field +: fields)).orElse(Some(field +: Chunk.fromIterable(fields))))

def includes[A](implicit schema: Schema.Record[A]): SearchRequest = {
val fields = Chunk.fromIterable(getFieldNames(schema))
self.copy(included = included.map(_ ++ fields).orElse(Some(fields)))
}

def routing(value: Routing): SearchRequest =
self.copy(routing = Some(value))

Expand All @@ -616,7 +635,18 @@ object ElasticRequest {
val sortJson: Json =
if (self.sortBy.nonEmpty) Obj("sort" -> Arr(self.sortBy.map(_.paramsToJson): _*)) else Obj()

fromJson merge sizeJson merge highlightsJson merge sortJson merge self.query.toJson merge searchAfterJson
val sourceJson: Json =
(included, excluded) match {
case (None, None) => Obj()
case (included, excluded) =>
Obj("_source" -> {
val includes = included.fold(Obj())(included => Obj("includes" -> Arr(included.map(_.toJson): _*)))
val excludes = excluded.fold(Obj())(excluded => Obj("excludes" -> Arr(excluded.map(_.toJson): _*)))
includes merge excludes
})
}

fromJson merge sizeJson merge highlightsJson merge sortJson merge self.query.toJson merge searchAfterJson merge sourceJson
}
}

Expand All @@ -625,7 +655,8 @@ object ElasticRequest {
with HasFrom[SearchAndAggregateRequest]
with HasRouting[SearchAndAggregateRequest]
with HasSize[SearchAndAggregateRequest]
with HasSort[SearchAndAggregateRequest] {
with HasSort[SearchAndAggregateRequest]
with HasSourceFiltering[SearchAndAggregateRequest] {
def highlights(value: Highlights): SearchAndAggregateRequest

def searchAfter(value: Json): SearchAndAggregateRequest
Expand All @@ -636,18 +667,31 @@ object ElasticRequest {
query: ElasticQuery[_],
aggregation: ElasticAggregation,
sortBy: Chunk[Sort],
excluded: Option[Chunk[String]],
from: Option[Int],
highlights: Option[Highlights],
included: Option[Chunk[String]],
routing: Option[Routing],
searchAfter: Option[Json],
size: Option[Int]
) extends SearchAndAggregateRequest { self =>
def excludes(field: String, fields: String*): SearchAndAggregateRequest =
self.copy(excluded = excluded.map(_ ++ (field +: fields)).orElse(Some(field +: Chunk.fromIterable(fields))))

def from(value: Int): SearchAndAggregateRequest =
self.copy(from = Some(value))

def highlights(value: Highlights): SearchAndAggregateRequest =
self.copy(highlights = Some(value))

def includes(field: String, fields: String*): SearchAndAggregateRequest =
self.copy(included = included.map(_ ++ (field +: fields)).orElse(Some(field +: Chunk.fromIterable(fields))))

def includes[A](implicit schema: Schema.Record[A]): SearchAndAggregateRequest = {
val fields = Chunk.fromIterable(getFieldNames(schema))
self.copy(included = included.map(_ ++ fields).orElse(Some(fields)))
}

def routing(value: Routing): SearchAndAggregateRequest =
self.copy(routing = Some(value))

Expand All @@ -672,13 +716,25 @@ object ElasticRequest {
val sortJson: Json =
if (self.sortBy.nonEmpty) Obj("sort" -> Arr(self.sortBy.map(_.paramsToJson): _*)) else Obj()

val sourceJson: Json =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. (Arnold's comment)

(included, excluded) match {
case (None, None) => Obj()
case (included, excluded) =>
Obj("_source" -> {
val includes = included.fold(Obj())(included => Obj("includes" -> Arr(included.map(_.toJson): _*)))
val excludes = excluded.fold(Obj())(excluded => Obj("excludes" -> Arr(excluded.map(_.toJson): _*)))
includes merge excludes
})
}

fromJson merge
sizeJson merge
highlightsJson merge
sortJson merge
self.query.toJson merge
aggregation.toJson merge
searchAfterJson
searchAfterJson merge
sourceJson
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2022 LambdaWorks
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package zio.elasticsearch.request.options

import zio.schema.Schema

private[elasticsearch] trait HasSourceFiltering[R <: HasSourceFiltering[R]] {

/**
* Specifies one or more fields to be excluded in the response of a [[zio.elasticsearch.ElasticRequest.SearchRequest]]
* or a [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]].
*
* @param field
* a field to be excluded
* @param fields
* fields to be excluded
* @return
* an instance of a [[zio.elasticsearch.ElasticRequest.SearchRequest]] or a
* [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] with specified fields to be excluded.
*/
def excludes(field: String, fields: String*): R

/**
* Specifies one or more fields to be included in the response of a [[zio.elasticsearch.ElasticRequest.SearchRequest]]
* or a [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]].
*
* @param field
* a field to be included
* @param fields
* fields to be included
* @return
* an instance of a [[zio.elasticsearch.ElasticRequest.SearchRequest]] or a
* [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] with specified fields to be included.
*/
def includes(field: String, fields: String*): R

/**
* Specifies fields to be included in the response of a [[zio.elasticsearch.ElasticRequest.SearchRequest]] or a
* [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] based on the schema of a case class.
*
* @tparam A
* a case class whose fields will be included in the response
* @param schema
* a record schema of [[A]]
* @return
* an instance of a [[zio.elasticsearch.ElasticRequest.SearchRequest]] or a
* [[zio.elasticsearch.ElasticRequest.SearchAndAggregateRequest]] with specified fields to be excluded.
*/
def includes[A](implicit schema: Schema.Record[A]): R

protected final def getFieldNames(schema: Schema.Record[_]): List[String] = {
def extractInnerSchema(schema: Schema[_]): Schema[_] =
Schema.force(schema) match {
case schema: Schema.Sequence[_, _, _] => Schema.force(schema.elementSchema)
case schema => schema
}

def loop(schema: Schema.Record[_], prefix: Option[String]): List[String] =
schema.fields.toList.flatMap { field =>
extractInnerSchema(field.schema) match {
case schema: Schema.Record[_] => loop(schema, prefix.map(_ + "." + field.name).orElse(Some(field.name)))
case _ => List(prefix.fold[String](field.name)(_ + "." + field.name))
}
}

loop(schema, None)
}
}
Loading