From 7b7f32231fb0f4df88dae3c8bffbf485ba30ba06 Mon Sep 17 00:00:00 2001
From: Matteo Merli <mmerli@apache.org>
Date: Tue, 8 Oct 2024 10:18:33 -0700
Subject: [PATCH] Client API changes for secondary idx (#545)

---
 oxia/async_client_impl.go     |  29 ++++++----
 oxia/internal/model/model.go  |   2 +
 oxia/options_list.go          |  21 +++++++
 oxia/options_put.go           |  20 +++++++
 oxia/options_range_scan.go    |   2 +-
 oxia/proto_utils.go           |  11 ++++
 oxia/sync_client_impl_test.go | 102 ++++++++++++++++++++++++++++++++++
 7 files changed, 174 insertions(+), 13 deletions(-)

diff --git a/oxia/async_client_impl.go b/oxia/async_client_impl.go
index 849f9966..67841a0b 100644
--- a/oxia/async_client_impl.go
+++ b/oxia/async_client_impl.go
@@ -137,6 +137,7 @@ func (c *clientImpl) Put(key string, value []byte, options ...PutOption) <-chan
 		SequenceKeysDeltas: opts.sequenceKeysDeltas,
 		PartitionKey:       opts.partitionKey,
 		Callback:           callback,
+		SecondaryIndexes:   toSecondaryIndexes(opts.secondaryIndexes),
 	}
 	if opts.ephemeral {
 		putCall.ClientIdentity = &c.options.identity
@@ -319,11 +320,13 @@ func processAllGetResponses(key string, results []*proto.GetResponse, comparison
 	close(ch)
 }
 
-func (c *clientImpl) listFromShard(ctx context.Context, minKeyInclusive string, maxKeyExclusive string, shardId int64, ch chan<- ListResult) {
+func (c *clientImpl) listFromShard(ctx context.Context, minKeyInclusive string, maxKeyExclusive string, shardId int64, secondaryIndexName *string,
+	ch chan<- ListResult) {
 	request := &proto.ListRequest{
-		Shard:          &shardId,
-		StartInclusive: minKeyInclusive,
-		EndExclusive:   maxKeyExclusive,
+		Shard:              &shardId,
+		StartInclusive:     minKeyInclusive,
+		EndExclusive:       maxKeyExclusive,
+		SecondaryIndexName: secondaryIndexName,
 	}
 
 	client, err := c.executor.ExecuteList(ctx, request)
@@ -355,7 +358,7 @@ func (c *clientImpl) List(ctx context.Context, minKeyInclusive string, maxKeyExc
 		// If the partition key is specified, we only need to make the request to one shard
 		shardId := c.getShardForKey("", opts)
 		go func() {
-			c.listFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardId, ch)
+			c.listFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardId, opts.secondaryIndexName, ch)
 			close(ch)
 		}()
 	} else {
@@ -368,7 +371,7 @@ func (c *clientImpl) List(ctx context.Context, minKeyInclusive string, maxKeyExc
 			go func() {
 				defer wg.Done()
 
-				c.listFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardIdPtr, ch)
+				c.listFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardIdPtr, opts.secondaryIndexName, ch)
 			}()
 		}
 
@@ -381,11 +384,13 @@ func (c *clientImpl) List(ctx context.Context, minKeyInclusive string, maxKeyExc
 	return ch
 }
 
-func (c *clientImpl) rangeScanFromShard(ctx context.Context, minKeyInclusive string, maxKeyExclusive string, shardId int64, ch chan<- GetResult) {
+func (c *clientImpl) rangeScanFromShard(ctx context.Context, minKeyInclusive string, maxKeyExclusive string, shardId int64, secondaryIndexName *string,
+	ch chan<- GetResult) {
 	request := &proto.RangeScanRequest{
-		Shard:          &shardId,
-		StartInclusive: minKeyInclusive,
-		EndExclusive:   maxKeyExclusive,
+		Shard:              &shardId,
+		StartInclusive:     minKeyInclusive,
+		EndExclusive:       maxKeyExclusive,
+		SecondaryIndexName: secondaryIndexName,
 	}
 
 	client, err := c.executor.ExecuteRangeScan(ctx, request)
@@ -421,7 +426,7 @@ func (c *clientImpl) RangeScan(ctx context.Context, minKeyInclusive string, maxK
 		// If the partition key is specified, we only need to make the request to one shard
 		shardId := c.getShardForKey("", opts)
 		go func() {
-			c.rangeScanFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardId, outCh)
+			c.rangeScanFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardId, opts.secondaryIndexName, outCh)
 		}()
 	} else {
 		// Do the list on all shards and aggregate the responses
@@ -433,7 +438,7 @@ func (c *clientImpl) RangeScan(ctx context.Context, minKeyInclusive string, maxK
 			ch := make(chan GetResult)
 			channels[i] = ch
 			go func() {
-				c.rangeScanFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardIdPtr, ch)
+				c.rangeScanFromShard(ctx, minKeyInclusive, maxKeyExclusive, shardIdPtr, opts.secondaryIndexName, ch)
 			}()
 		}
 
diff --git a/oxia/internal/model/model.go b/oxia/internal/model/model.go
index 3d4c1eaa..01ea81f4 100644
--- a/oxia/internal/model/model.go
+++ b/oxia/internal/model/model.go
@@ -26,6 +26,7 @@ type PutCall struct {
 	SessionId          *int64
 	ClientIdentity     *string
 	PartitionKey       *string
+	SecondaryIndexes   []*proto.SecondaryIndex
 	Callback           func(*proto.PutResponse, error)
 }
 
@@ -56,6 +57,7 @@ func (r PutCall) ToProto() *proto.PutRequest {
 		ClientIdentity:    r.ClientIdentity,
 		PartitionKey:      r.PartitionKey,
 		SequenceKeyDelta:  r.SequenceKeysDeltas,
+		SecondaryIndexes:  r.SecondaryIndexes,
 	}
 }
 
diff --git a/oxia/options_list.go b/oxia/options_list.go
index bf4c1bab..a3add2fb 100644
--- a/oxia/options_list.go
+++ b/oxia/options_list.go
@@ -16,11 +16,14 @@ package oxia
 
 type listOptions struct {
 	baseOptions
+
+	secondaryIndexName *string
 }
 
 // ListOption represents an option for the [SyncClient.List] operation.
 type ListOption interface {
 	applyList(opts *listOptions)
+	applyRangeScan(opts *rangeScanOptions)
 }
 
 func newListOptions(opts []ListOption) *listOptions {
@@ -30,3 +33,21 @@ func newListOptions(opts []ListOption) *listOptions {
 	}
 	return listOpts
 }
+
+type useIndex struct {
+	indexName string
+}
+
+func (u *useIndex) applyList(opts *listOptions) {
+	opts.secondaryIndexName = &u.indexName
+}
+
+func (u *useIndex) applyRangeScan(opts *rangeScanOptions) {
+	opts.secondaryIndexName = &u.indexName
+}
+
+// UseIndex let the users specify a different index to follow for the
+// Note: The returned list will contain they primary keys of the records
+func UseIndex(indexName string) ListOption {
+	return &useIndex{indexName}
+}
diff --git a/oxia/options_put.go b/oxia/options_put.go
index ba8d0a2e..fba6a532 100644
--- a/oxia/options_put.go
+++ b/oxia/options_put.go
@@ -21,6 +21,7 @@ type putOptions struct {
 	expectedVersion    *int64
 	ephemeral          bool
 	sequenceKeysDeltas []uint64
+	secondaryIndexes   []*secondaryIdxOption
 }
 
 // PutOption represents an option for the [SyncClient.Put] operation.
@@ -95,3 +96,22 @@ func (s *sequenceKeysDeltas) applyPut(opts *putOptions) {
 func SequenceKeysDeltas(delta ...uint64) PutOption {
 	return &sequenceKeysDeltas{delta}
 }
+
+type secondaryIdxOption struct {
+	indexName    string
+	secondaryKey string
+}
+
+func (s *secondaryIdxOption) applyPut(opts *putOptions) {
+	opts.secondaryIndexes = append(opts.secondaryIndexes, s)
+}
+
+// SecondaryIndex let the users specify additional keys to index the record
+// Index names are arbitrary strings and can be used in `List` and
+// `RangeScan` requests.
+// Secondary keys are not required to be unique.
+// Multiple secondary indexes can be passed on the same record, even
+// reusing multiple times the same indexName.
+func SecondaryIndex(indexName string, secondaryKey string) PutOption {
+	return &secondaryIdxOption{indexName, secondaryKey}
+}
diff --git a/oxia/options_range_scan.go b/oxia/options_range_scan.go
index 5b29dc01..fd719a59 100644
--- a/oxia/options_range_scan.go
+++ b/oxia/options_range_scan.go
@@ -15,7 +15,7 @@
 package oxia
 
 type rangeScanOptions struct {
-	baseOptions
+	listOptions
 }
 
 // RangeScanOption represents an option for the [SyncClient.RangeScan] operation.
diff --git a/oxia/proto_utils.go b/oxia/proto_utils.go
index f59fd4e7..d86272d5 100644
--- a/oxia/proto_utils.go
+++ b/oxia/proto_utils.go
@@ -97,3 +97,14 @@ func toError(status proto.Status) error {
 		return ErrUnknownStatus
 	}
 }
+
+func toSecondaryIndexes(secondaryIndexes []*secondaryIdxOption) (res []*proto.SecondaryIndex) {
+	for _, si := range secondaryIndexes {
+		res = append(res, &proto.SecondaryIndex{
+			IndexName:    si.indexName,
+			SecondaryKey: si.secondaryKey,
+		})
+	}
+
+	return res
+}
diff --git a/oxia/sync_client_impl_test.go b/oxia/sync_client_impl_test.go
index 38b981c9..6fec5645 100644
--- a/oxia/sync_client_impl_test.go
+++ b/oxia/sync_client_impl_test.go
@@ -16,6 +16,10 @@ package oxia
 
 import (
 	"context"
+	"fmt"
+	"github.com/streamnative/oxia/server"
+	"log/slog"
+	"strings"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -91,3 +95,101 @@ func assertCancellable(t *testing.T, operationFunc func(context.Context) error)
 
 	assert.ErrorIs(t, <-errCh, context.Canceled)
 }
+
+func TestSyncClientImpl_SecondaryIndexes(t *testing.T) {
+	config := server.NewTestConfig(t.TempDir())
+	// Test with multiple shards to ensure correctness across shards
+	config.NumShards = 1
+	standaloneServer, err := server.NewStandalone(config)
+	assert.NoError(t, err)
+
+	serviceAddress := fmt.Sprintf("localhost:%d", standaloneServer.RpcPort())
+	client, err := NewSyncClient(serviceAddress)
+	assert.NoError(t, err)
+
+	// ////////////////////////////////////////////////////////////////////////
+
+	ctx := context.Background()
+	for i := 0; i < 10; i++ {
+		primKey := fmt.Sprintf("%c", 'a'+i)
+		val := fmt.Sprintf("%d", i)
+		slog.Info("Adding record",
+			slog.String("key", primKey),
+			slog.String("value", val),
+		)
+		_, _, _ = client.Put(ctx, primKey, []byte(val), SecondaryIndex("val-idx", val))
+	}
+
+	// ////////////////////////////////////////////////////////////////////////
+
+	l, err := client.List(ctx, "1", "4", UseIndex("val-idx"))
+	assert.NoError(t, err)
+	assert.Equal(t, []string{"b", "c", "d"}, l)
+
+	// ////////////////////////////////////////////////////////////////////////
+
+	resCh := client.RangeScan(ctx, "1", "4", UseIndex("val-idx"))
+	i := 1
+	for res := range resCh {
+		assert.NoError(t, res.Err)
+
+		primKey := fmt.Sprintf("%c", 'a'+i)
+		val := fmt.Sprintf("%d", i)
+
+		slog.Info("Expected record",
+			slog.String("expected-key", primKey),
+			slog.String("expected-value", val),
+			slog.String("received-key", res.Key),
+			slog.String("received-value", string(res.Value)),
+		)
+		assert.Equal(t, primKey, res.Key)
+		assert.Equal(t, val, string(res.Value))
+		i++
+	}
+
+	assert.Equal(t, 4, i)
+
+	assert.NoError(t, client.Close())
+	assert.NoError(t, standaloneServer.Close())
+}
+
+func TestSyncClientImpl_SecondaryIndexesRepeated(t *testing.T) {
+	config := server.NewTestConfig(t.TempDir())
+	// Test with multiple shards to ensure correctness across shards
+	config.NumShards = 1
+	standaloneServer, err := server.NewStandalone(config)
+	assert.NoError(t, err)
+
+	serviceAddress := fmt.Sprintf("localhost:%d", standaloneServer.RpcPort())
+	client, err := NewSyncClient(serviceAddress)
+	assert.NoError(t, err)
+
+	// ////////////////////////////////////////////////////////////////////////
+
+	ctx := context.Background()
+	for i := 0; i < 10; i++ {
+		primKey := fmt.Sprintf("/%c", 'a'+i)
+		val := fmt.Sprintf("%c", 'a'+i)
+		slog.Info("Adding record",
+			slog.String("key", primKey),
+			slog.String("value", val),
+		)
+		_, _, _ = client.Put(ctx, primKey, []byte(val),
+			SecondaryIndex("val-idx", val),
+			SecondaryIndex("val-idx", strings.ToUpper(val)),
+		)
+	}
+
+	// ////////////////////////////////////////////////////////////////////////
+
+	l, err := client.List(ctx, "b", "e", UseIndex("val-idx"))
+	assert.NoError(t, err)
+	assert.Equal(t, []string{"/b", "/c", "/d"}, l)
+
+	l, err = client.List(ctx, "I", "d", UseIndex("val-idx"))
+	assert.NoError(t, err)
+	assert.Equal(t, []string{"/i", "/j", "/a", "/b", "/c"}, l)
+
+	assert.NoError(t, client.Close())
+	assert.NoError(t, standaloneServer.Close())
+}