diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e7d0cdbc7..7942bc476 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,15 +6,17 @@ jobs: runs-on: windows-latest strategy: matrix: - compiler: ["Visual Studio 17 2022"] + compiler: ["Visual Studio 17 2022", "MinGW Makefiles"] fail-fast: false + env: + EXTRA_FLAGS: "${{ matrix.compiler == 'Visual Studio 17 2022' && '-Ax64' || '' }}" steps: - name: Checkout repository uses: actions/checkout@v4 - name: Build Reindexer run: | mkdir build && cd build - cmake -G "${{matrix.compiler}}" .. -Ax64 + cmake -G "${{matrix.compiler}}" -DBUILD_ANN_INDEXES=builtin .. $EXTRA_FLAGS cmake --build . --config Release cmake --build . --config Release --target face cmake --build . --config Release --target swagger @@ -30,7 +32,7 @@ jobs: runs-on: windows-2019 strategy: matrix: - compiler: ["Visual Studio 16 2019", "MinGW Makefiles"] + compiler: ["Visual Studio 16 2019"] fail-fast: false steps: - name: Checkout repository @@ -38,7 +40,7 @@ jobs: - name: Build Reindexer run: | mkdir build && cd build - cmake -G "${{matrix.compiler}}" .. + cmake -G "${{matrix.compiler}}" -DBUILD_ANN_INDEXES=none .. cmake --build . --config Release cmake --build . --config Release --target face cmake --build . --config Release --target swagger @@ -183,68 +185,70 @@ jobs: if [[ -z "$SANITIZER" ]]; then go test -timeout 15m ./test/... -bench . -benchmem -benchtime 100ms -seedcount 50000 else - go test -timeout 35m ./test/... -bench . -benchmem -benchtime 100ms -seedcount 50000 + export TSAN_OPTIONS="halt_on_error=1 suppressions=$PWD/cpp_src/gtests/tsan.suppressions" + go test -timeout 35m ./test/... -bench . -benchmem -benchtime 100ms -seedcount 50000 -tags tiny_vectors fi else cd build ctest --verbose fi - test-pyreindexer: - strategy: - matrix: - os: [ubuntu-22.04, ubuntu-24.04] - fail-fast: false - runs-on: ${{matrix.os}} - needs: build - if: always() - env: - OS: ${{matrix.os}} - steps: - - name: Download ${{matrix.os}} Artifacts - uses: actions/download-artifact@v4 - with: - name: ${{matrix.os}} - - name: 'Untar Artifacts' - run: tar -xvf artifacts.tar - - name: Prepare Environment - run: | - if [[ $OS == ubuntu* ]]; then - sudo ./dependencies.sh - python3 -m pip install setuptools build - else - ./dependencies.sh - fi - - name: Install Reindexer - run: | - cd build - if [[ $OS == ubuntu* ]]; then - sudo dpkg -i reindexer-4-dev*.deb - sudo apt-get install -f - sudo dpkg -i reindexer-4-server*.deb - sudo apt-get install -f - else - for f in reindexer-*.tar.gz; do tar -xvzf "$f"; done - cp -R ./usr/local/include/reindexer /usr/local/include/reindexer - cp -R ./usr/local/lib/reindexer /usr/local/lib/reindexer - cp ./usr/local/lib/libreindexer.a /usr/local/lib/libreindexer.a - cp ./usr/local/lib/libreindexer_server_library.a /usr/local/lib/libreindexer_server_library.a - cp ./usr/local/lib/libreindexer_server_resources.a /usr/local/lib/libreindexer_server_resources.a - cp ./usr/local/lib/pkgconfig/libreindexer.pc /usr/local/lib/pkgconfig/libreindexer.pc - cp ./usr/local/lib/pkgconfig/libreindexer_server.pc /usr/local/lib/pkgconfig/libreindexer_server.pc - cp ./usr/local/bin/reindexer_server /usr/local/bin/reindexer_server - cp ./usr/local/etc/reindexer.conf.pkg /usr/local/etc/reindexer.conf.pkg - fi - - name: Clone PyReindexer - uses: actions/checkout@v4 - with: - repository: restream/reindexer-py - - name: Install PyReindexer - run: | - python -m build - python -m pip install . 
- - name: Test PyReindexer - run: | - cd pyreindexer - ../.github/workflows/test.sh +# TODO: Reenable after binding's compatibility fix +# test-pyreindexer: +# strategy: +# matrix: +# os: [ubuntu-22.04, ubuntu-24.04] +# fail-fast: false +# runs-on: ${{matrix.os}} +# needs: build +# if: always() +# env: +# OS: ${{matrix.os}} +# steps: +# - name: Download ${{matrix.os}} Artifacts +# uses: actions/download-artifact@v4 +# with: +# name: ${{matrix.os}} +# - name: 'Untar Artifacts' +# run: tar -xvf artifacts.tar +# - name: Prepare Environment +# run: | +# if [[ $OS == ubuntu* ]]; then +# sudo ./dependencies.sh +# python3 -m pip install setuptools build +# else +# ./dependencies.sh +# fi +# - name: Install Reindexer +# run: | +# cd build +# if [[ $OS == ubuntu* ]]; then +# sudo dpkg -i reindexer-4-dev*.deb +# sudo apt-get install -f +# sudo dpkg -i reindexer-4-server*.deb +# sudo apt-get install -f +# else +# for f in reindexer-*.tar.gz; do tar -xvzf "$f"; done +# cp -R ./usr/local/include/reindexer /usr/local/include/reindexer +# cp -R ./usr/local/lib/reindexer /usr/local/lib/reindexer +# cp ./usr/local/lib/libreindexer.a /usr/local/lib/libreindexer.a +# cp ./usr/local/lib/libreindexer_server_library.a /usr/local/lib/libreindexer_server_library.a +# cp ./usr/local/lib/libreindexer_server_resources.a /usr/local/lib/libreindexer_server_resources.a +# cp ./usr/local/lib/pkgconfig/libreindexer.pc /usr/local/lib/pkgconfig/libreindexer.pc +# cp ./usr/local/lib/pkgconfig/libreindexer_server.pc /usr/local/lib/pkgconfig/libreindexer_server.pc +# cp ./usr/local/bin/reindexer_server /usr/local/bin/reindexer_server +# cp ./usr/local/etc/reindexer.conf.pkg /usr/local/etc/reindexer.conf.pkg +# fi +# - name: Clone PyReindexer +# uses: actions/checkout@v4 +# with: +# repository: restream/reindexer-py +# - name: Install PyReindexer +# run: | +# python -m build +# python -m pip install . +# - name: Test PyReindexer +# run: | +# cd pyreindexer +# ../.github/workflows/test.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 68580f3d5..899ba6df5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,11 @@ -cmake_minimum_required(VERSION 3.10..3.13) +cmake_minimum_required(VERSION 3.18) project(reindexer) enable_testing() set(CMAKE_DISABLE_IN_SOURCE_BUILD ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +set(CMAKE_PCH_WARN_INVALID OFF) set(REINDEXER_SOURCE_PATH ${PROJECT_SOURCE_DIR}/cpp_src) add_subdirectory(cpp_src) diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index e8a8e5275..000000000 --- a/appveyor.yml +++ /dev/null @@ -1,43 +0,0 @@ -version: '{build}' - -# Uncomment this to enable the fast build environment if your account does not -# support it automatically: -image: Visual Studio 2019 - -environment: - matrix: - - BUILD_TYPE: Release - COMPILER: MinGW-w64 - PLATFORM: x64 - TOOLCHAIN: x86_64-8.1.0-posix-seh-rt_v6-rev0 - APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - -# - BUILD_TYPE: Release -# COMPILER: MinGW -# PLATFORM: Win32 -# TOOLCHAIN: i686-8.1.0-posix-dwarf-rt_v6-rev0 -# APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019 - -build_script: - - git describe --tags - - mkdir build - - cd build - - if [%COMPILER%]==[MinGW] set PATH=C:\MinGW-w64\%TOOLCHAIN%\mingw32\bin;%PATH:C:\Program Files\Git\usr\bin;=% - - if [%COMPILER%]==[MinGW-w64] set PATH=C:\MinGW-w64\%TOOLCHAIN%\mingw64\bin;%PATH:C:\Program Files\Git\usr\bin;=% - - if [%COMPILER%]==[MinGW] cmake -G "MinGW Makefiles" -DCMAKE_PREFIX_PATH=C:\mingw-w64\%TOOLCHAIN% -DCMAKE_CXX_STANDARD=17 -DCMAKE_BUILD_TYPE=%BUILD_TYPE% .. 
- - if [%COMPILER%]==[MinGW-w64] cmake -G "MinGW Makefiles" -DCMAKE_PREFIX_PATH=C:\mingw-w64\%TOOLCHAIN% -DCMAKE_CXX_STANDARD=17 -DCMAKE_BUILD_TYPE=%BUILD_TYPE% .. - - - cmake --build . --config %BUILD_TYPE% - - cmake --build . --config %BUILD_TYPE% --target face - - cmake --build . --config %BUILD_TYPE% --target swagger - - cpack - -artifacts: - - path: build/*.exe - name: reindexer_server - - -#test_script: - -on_success: -#- cd C:\ diff --git a/bindings.go b/bindings.go index c9d2f14be..06241edcd 100644 --- a/bindings.go +++ b/bindings.go @@ -11,9 +11,9 @@ import ( "github.com/prometheus/client_golang/prometheus" otelattr "go.opentelemetry.io/otel/attribute" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/bindings/builtinserver/config" - "github.com/restream/reindexer/v4/cjson" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/bindings/builtinserver/config" + "github.com/restream/reindexer/v5/cjson" ) const ( diff --git a/bindings/builtin/builtin.go b/bindings/builtin/builtin.go index f73bc5fe3..0b6aceafe 100644 --- a/bindings/builtin/builtin.go +++ b/bindings/builtin/builtin.go @@ -18,8 +18,8 @@ import ( "time" "unsafe" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/cjson" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/cjson" ) const defCgoLimit = 2000 @@ -250,12 +250,13 @@ func (binding *Builtin) Init(u []url.URL, eh bindings.EventsHandler, options ... caps := *bindings.DefaultBindingCapabilities(). WithResultsWithShardIDs(true). WithQrIdleTimeouts(true). - WithIncarnationTags(true) + WithIncarnationTags(true). + WithFloatRank(true) ccaps := C.BindingCapabilities{ caps: C.int64_t(caps.Value), } - return err2go(C.reindexer_connect_v4(binding.rx, str2c(u[0].Host+u[0].Path), opts, str2c(bindings.ReindexerVersion), ccaps)) + return err2go(C.reindexer_connect(binding.rx, str2c(u[0].Host+u[0].Path), opts, str2c(bindings.ReindexerVersion), ccaps)) } func (binding *Builtin) StartWatchOnCtx(ctx context.Context) (CCtxWrapper, error) { diff --git a/bindings/builtin/builtin_posix.go b/bindings/builtin/builtin_posix.go index edcd9a7bf..42c467467 100644 --- a/bindings/builtin/builtin_posix.go +++ b/bindings/builtin/builtin_posix.go @@ -4,7 +4,7 @@ package builtin // #cgo pkg-config: libreindexer -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable // #cgo LDFLAGS: -g import "C" diff --git a/bindings/builtin/builtin_windows.go b/bindings/builtin/builtin_windows.go index 93df6f576..baff93cf5 100644 --- a/bindings/builtin/builtin_windows.go +++ b/bindings/builtin/builtin_windows.go @@ -3,7 +3,7 @@ package builtin -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra -I../../cpp_src +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra -I../../cpp_src // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable -I../../cpp_src -// #cgo LDFLAGS: -L${SRCDIR}/../../build/cpp_src/ -lreindexer -lleveldb -lsnappy -g -lshlwapi -ldbghelp -lws2_32 +// #cgo LDFLAGS: -L${SRCDIR}/../../build/cpp_src/ -lreindexer -lleveldb -lsnappy -g -lgomp -lshlwapi -ldbghelp -lws2_32 import "C" diff --git a/bindings/builtin/cgoeventshandler.go b/bindings/builtin/cgoeventshandler.go index 2906a9142..63f046c13 100644 --- 
a/bindings/builtin/cgoeventshandler.go +++ b/bindings/builtin/cgoeventshandler.go @@ -12,7 +12,7 @@ import ( "time" "unsafe" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" ) type CGOEventsHandler struct { diff --git a/bindings/builtin/posix_config.go.in b/bindings/builtin/posix_config.go.in index 2cecac704..da67c6e29 100644 --- a/bindings/builtin/posix_config.go.in +++ b/bindings/builtin/posix_config.go.in @@ -3,7 +3,7 @@ package builtin -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra @cgo_cxx_flags@ +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra @cgo_cxx_flags@ // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable @cgo_c_flags@ // #cgo LDFLAGS: @cgo_ld_flags@ -g import "C" diff --git a/bindings/builtinserver/builtinserver.go b/bindings/builtinserver/builtinserver.go index 9451b27f1..64a30c0df 100644 --- a/bindings/builtinserver/builtinserver.go +++ b/bindings/builtinserver/builtinserver.go @@ -12,9 +12,9 @@ import ( "time" "unsafe" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/bindings/builtin" - "github.com/restream/reindexer/v4/bindings/builtinserver/config" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/bindings/builtin" + "github.com/restream/reindexer/v5/bindings/builtinserver/config" ) var defaultStartupTimeout time.Duration = time.Minute * 3 diff --git a/bindings/builtinserver/builtinserver_posix.go b/bindings/builtinserver/builtinserver_posix.go index acc365bb9..7066fea45 100644 --- a/bindings/builtinserver/builtinserver_posix.go +++ b/bindings/builtinserver/builtinserver_posix.go @@ -4,7 +4,7 @@ package builtinserver // #cgo pkg-config: libreindexer_server -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable // #cgo LDFLAGS: -g import "C" diff --git a/bindings/builtinserver/builtinserver_windows.go b/bindings/builtinserver/builtinserver_windows.go index 6a10df3c4..bd69952d9 100644 --- a/bindings/builtinserver/builtinserver_windows.go +++ b/bindings/builtinserver/builtinserver_windows.go @@ -3,7 +3,7 @@ package builtinserver -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra -I../../cpp_src +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra -I../../cpp_src // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable -I../../cpp_src -// #cgo LDFLAGS: -L${SRCDIR}/../../build/cpp_src/ -L${SRCDIR}/../../build/cpp_src/server/ -lreindexer_server_library -lreindexer -lreindexer_server_resources -lleveldb -lsnappy -g -lstdc++ -lshlwapi -ldbghelp -lws2_32 +// #cgo LDFLAGS: -L${SRCDIR}/../../build/cpp_src/ -L${SRCDIR}/../../build/cpp_src/server/ -lreindexer_server_library -lreindexer -lreindexer_server_resources -lleveldb -lsnappy -g -lgomp -lstdc++ -lshlwapi -ldbghelp -lws2_32 import "C" diff --git a/bindings/builtinserver/posix_config.go.in b/bindings/builtinserver/posix_config.go.in index 5cd111530..e7da0cca1 100644 --- a/bindings/builtinserver/posix_config.go.in +++ b/bindings/builtinserver/posix_config.go.in @@ -3,7 +3,7 @@ package builtinserver -// #cgo CXXFLAGS: -std=c++17 -g -O2 -Wall -Wpedantic -Wextra @cgo_cxx_flags@ +// #cgo CXXFLAGS: -std=c++20 -g -O2 -Wall -Wpedantic -Wextra @cgo_cxx_flags@ // #cgo CFLAGS: -std=c99 -g -O2 -Wall -Wpedantic -Wno-unused-variable @cgo_c_flags@ // #cgo 
LDFLAGS: @cgo_ld_flags@ -g import "C" diff --git a/bindings/consts.go b/bindings/consts.go index 3135f9b84..5d43f0405 100644 --- a/bindings/consts.go +++ b/bindings/consts.go @@ -46,16 +46,18 @@ const ( OpAnd = 2 OpNot = 3 - ValueInt64 = 0 - ValueDouble = 1 - ValueString = 2 - ValueBool = 3 - ValueNull = 4 - ValueInt = 8 - ValueUndefined = 9 - ValueComposite = 10 - ValueTuple = 11 - ValueUuid = 12 + ValueInt64 = 0 + ValueDouble = 1 + ValueString = 2 + ValueBool = 3 + ValueNull = 4 + ValueInt = 8 + ValueUndefined = 9 + ValueComposite = 10 + ValueTuple = 11 + ValueUuid = 12 + ValueFloatVector = 13 + ValueFloat = 14 QueryCondition = 0 QueryDistinct = 1 @@ -88,6 +90,15 @@ const ( QueryAlwaysTrueCondition = 28 QuerySubQueryCondition = 29 QueryFieldSubQueryCondition = 30 + QueryLocal = 31 + QueryKnnCondition = 32 + + KnnQueryTypeBase = 0 + KnnQueryTypeBruteForce = 1 + KnnQueryTypeHnsw = 2 + KnnQueryTypeIvf = 3 + + KnnQueryParamsVersion = 0 LeftJoin = 0 InnerJoin = 1 @@ -116,6 +127,7 @@ const ( QueryResultShardingVersion = 3 QueryResultShardId = 4 QueryResultIncarnationTags = 5 + QueryResultRankFormat = 6 QueryStrictModeNotSet = 0 QueryStrictModeNone = 1 @@ -130,7 +142,7 @@ const ( ResultsWithPayloadTypes = 0x10 ResultsWithItemID = 0x20 - ResultsWithPercents = 0x40 + ResultsWithRank = 0x40 ResultsWithNsID = 0x80 ResultsWithJoined = 0x100 ResultsWithShardId = 0x800 @@ -154,6 +166,9 @@ const ( BindingCapabilityQrIdleTimeouts = 1 BindingCapabilityResultsWithShardIDs = 1 << 1 BindingCapabilityNamespaceIncarnations = 1 << 2 + BindingCapabilityComplexRank = 1 << 3 + + RankFormatSingleFloat = 0 ErrOK = 0 ErrParseSQL = 1 @@ -190,6 +205,9 @@ const ( ErrQrUIDMissmatch = 36 ErrSystem = 37 ErrAssert = 38 + ErrParseYAML = 39 + ErrNamespaceOverwritten = 40 + ErrVersion = 41 ) const ( diff --git a/bindings/cproto/connection.go b/bindings/cproto/connection.go index ba818ba4e..291cdf29a 100644 --- a/bindings/cproto/connection.go +++ b/bindings/cproto/connection.go @@ -15,8 +15,8 @@ import ( "sync/atomic" "time" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/cjson" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/cjson" ) type bufPtr struct { diff --git a/bindings/cproto/connection.mock.go b/bindings/cproto/connection.mock.go index 6e5a6d1f5..66521a39f 100644 --- a/bindings/cproto/connection.mock.go +++ b/bindings/cproto/connection.mock.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" ) type MockConnection struct { diff --git a/bindings/cproto/cproto.go b/bindings/cproto/cproto.go index 162f432f1..fde63854b 100644 --- a/bindings/cproto/cproto.go +++ b/bindings/cproto/cproto.go @@ -15,8 +15,8 @@ import ( "sync/atomic" "time" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/cjson" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/cjson" ) const ( @@ -277,7 +277,8 @@ func (binding *NetCProto) Init(u []url.URL, eh bindings.EventsHandler, options . binding.caps = *bindings.DefaultBindingCapabilities(). WithQrIdleTimeouts(true). WithResultsWithShardIDs(true). - WithIncarnationTags(true) + WithIncarnationTags(true). 
+ WithFloatRank(true) for _, option := range options { switch v := option.(type) { diff --git a/bindings/cproto/cproto_test.go b/bindings/cproto/cproto_test.go index 6c9bb153a..0086f0974 100644 --- a/bindings/cproto/cproto_test.go +++ b/bindings/cproto/cproto_test.go @@ -15,9 +15,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/restream/reindexer/v4" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/test/helpers" + "github.com/restream/reindexer/v5" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/test/helpers" ) var benchmarkSeed = flag.Int64("seed", time.Now().Unix(), "seed number for random") diff --git a/bindings/cproto/cproto_unit_test.go b/bindings/cproto/cproto_unit_test.go index 086e50f4a..41626f43f 100644 --- a/bindings/cproto/cproto_unit_test.go +++ b/bindings/cproto/cproto_unit_test.go @@ -10,7 +10,7 @@ import ( "net/url" "testing" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/bindings/cproto/encdec.go b/bindings/cproto/encdec.go index 92a019cfa..f5b345b35 100644 --- a/bindings/cproto/encdec.go +++ b/bindings/cproto/encdec.go @@ -4,8 +4,8 @@ import ( "fmt" "unsafe" - "github.com/restream/reindexer/v4/bindings" - "github.com/restream/reindexer/v4/cjson" + "github.com/restream/reindexer/v5/bindings" + "github.com/restream/reindexer/v5/cjson" "github.com/golang/snappy" ) diff --git a/bindings/cproto/netbuf.go b/bindings/cproto/netbuf.go index a4b6ccb3c..e95528e05 100644 --- a/bindings/cproto/netbuf.go +++ b/bindings/cproto/netbuf.go @@ -4,7 +4,7 @@ import ( "context" "sync" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" "github.com/golang/snappy" ) diff --git a/bindings/cproto/pool.go b/bindings/cproto/pool.go index 94f61220a..e6017d829 100644 --- a/bindings/cproto/pool.go +++ b/bindings/cproto/pool.go @@ -4,7 +4,7 @@ import ( "math/rand" "sync/atomic" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" ) type pool struct { diff --git a/bindings/interface.go b/bindings/interface.go index fcc029e28..0bff4c66b 100644 --- a/bindings/interface.go +++ b/bindings/interface.go @@ -6,10 +6,25 @@ import ( "net/url" "time" - "github.com/restream/reindexer/v4/bindings/builtinserver/config" - "github.com/restream/reindexer/v4/jsonschema" + "github.com/restream/reindexer/v5/bindings/builtinserver/config" + "github.com/restream/reindexer/v5/jsonschema" ) +const ( + MultithreadingMode_SingleThread = 0 + MultithreadingMode_MultithreadTransactions = 1 +) + +type FloatVectorIndexOpts struct { + Metric string `json:"metric"` + Dimension int `json:"dimension"` + M int `json:"m,omitempty"` + EfConstruction int `json:"ef_construction,omitempty"` + StartSize int `json:"start_size,omitempty"` + CentroidsCount int `json:"centroids_count,omitempty"` + MultithreadingMode int `json:"multithreading,omitempty"` +} + type IndexDef struct { Name string `json:"name"` JSONPaths []string `json:"json_paths"` @@ -147,6 +162,16 @@ func (bc *BindingCapabilities) WithIncarnationTags(value bool) *BindingCapabilit return bc } +// Enable float rank format +func (bc *BindingCapabilities) 
WithFloatRank(value bool) *BindingCapabilities { + if value { + bc.Value |= int64(BindingCapabilityComplexRank) + } else { + bc.Value &= ^int64(BindingCapabilityComplexRank) + } + return bc +} + // go interface to reindexer_c.h interface type RawBuffer interface { GetBuf() []byte diff --git a/cjson/creflect.go b/cjson/creflect.go index db4976319..75757f102 100644 --- a/cjson/creflect.go +++ b/cjson/creflect.go @@ -5,21 +5,29 @@ import ( "reflect" "unsafe" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" ) const ( - valueInt = bindings.ValueInt - valueBool = bindings.ValueBool - valueInt64 = bindings.ValueInt64 - valueDouble = bindings.ValueDouble - valueString = bindings.ValueString - valueUuid = bindings.ValueUuid + valueInt = bindings.ValueInt + valueBool = bindings.ValueBool + valueInt64 = bindings.ValueInt64 + valueDouble = bindings.ValueDouble + valueString = bindings.ValueString + valueUuid = bindings.ValueUuid + valueFloatVector = bindings.ValueFloatVector + valueFloat = bindings.ValueFloat +) + +const ( + floatVectorDimensionOffset = 48 + floatVectorPtrMask = (uint64(1) << floatVectorDimensionOffset) - uint64(1) ) // to avoid gcc toolchain requirement // types from C. Danger expectation about go struct packing is like C struct packing type Cdouble float64 +type Cfloat float32 type Cint int32 type Cuint uint32 type Cunsigned uint32 @@ -32,8 +40,7 @@ type ArrayHeader struct { } type PStringHeader struct { - cstr unsafe.Pointer - len Cint + len Cuint } type LStringHeader struct { @@ -42,11 +49,12 @@ type LStringHeader struct { } type payloadFieldType struct { - Type int - Name string - Offset uintptr - Size uintptr - IsArray bool + Type int + Name string + Offset uintptr + Size uintptr + FloatVectorDimension uint16 + IsArray bool } type payloadType struct { @@ -61,6 +69,11 @@ func (pt *payloadType) Read(ser *Serializer, skip bool) { for i := 0; i < fieldsCount; i++ { fields[i].Type = int(ser.GetVarUInt()) + if fields[i].Type == valueFloatVector { + fields[i].FloatVectorDimension = uint16(ser.GetVarUInt()) + } else { + fields[i].FloatVectorDimension = 0 + } fields[i].Name = ser.GetVString() fields[i].Offset = uintptr(ser.GetVarUInt()) fields[i].Size = uintptr(ser.GetVarUInt()) @@ -168,6 +181,13 @@ func (pl *payloadIface) getUuid(field, idx int) string { return createUuid(*(*[2]uint64)(p)) } +func (pl *payloadIface) getFloatVector(field, idx int) []float32 { + p := pl.ptr(field, idx, valueFloatVector) + cFloatVectorView := *(*uint64)(p) + dim := cFloatVectorView >> floatVectorDimensionOffset + return (*[1 << 30]float32)(unsafe.Pointer(uintptr(cFloatVectorView & floatVectorPtrMask)))[:dim:dim] +} + func (pl *payloadIface) getFloat64(field, idx int) float64 { p := pl.ptr(field, idx, valueDouble) return float64(*(*Cdouble)(p)) @@ -194,8 +214,10 @@ func (pl *payloadIface) getBytes(field, idx int) []byte { strHdr := (*LStringHeader)(unsafe.Pointer(ppstring)) return (*[1 << 30]byte)(unsafe.Pointer(&strHdr.data))[:strHdr.len:strHdr.len] case keySringType: - strHdr := (*PStringHeader)(unsafe.Pointer(ppstring + pl.t.PStringHdrOffset)) - return (*[1 << 30]byte)(strHdr.cstr)[:strHdr.len:strHdr.len] + hdrPtr := ppstring + pl.t.PStringHdrOffset + strHdr := (*PStringHeader)(unsafe.Pointer(hdrPtr)) + dataPtr := unsafe.Pointer(hdrPtr + unsafe.Sizeof(PStringHeader{})) + return (*[1 << 30]byte)(dataPtr)[:strHdr.len:strHdr.len] default: panic(fmt.Sprintf("Unknow string type in payload value: %d", psType)) } @@ -221,14 +243,16 @@ func (pl 
*payloadIface) getArrayLen(field int) int { func (pl *payloadIface) getValue(field int, idx int, v reflect.Value) { k := v.Type().Kind() - if k == reflect.Slice { - el := reflect.New(v.Type().Elem()).Elem() - extSlice := reflect.Append(v, el) - v.Set(extSlice) - v = v.Index(v.Len() - 1) - k = v.Type().Kind() - } else if k == reflect.Array { - panic(fmt.Errorf("can not put single indexed value into the fixed size array")) + if pl.t.Fields[field].Type != valueFloatVector { + if k == reflect.Slice { + el := reflect.New(v.Type().Elem()).Elem() + extSlice := reflect.Append(v, el) + v.Set(extSlice) + v = v.Index(v.Len() - 1) + k = v.Type().Kind() + } else if k == reflect.Array { + panic(fmt.Errorf("can not put single indexed value into the fixed size array")) + } } switch pl.t.Fields[field].Type { case valueBool: @@ -257,6 +281,27 @@ func (pl *payloadIface) getValue(field int, idx int, v reflect.Value) { v.SetString(pl.getString(field, idx)) case valueUuid: v.SetString(pl.getUuid(field, idx)) + case valueFloatVector: + vec := pl.getFloatVector(field, idx) + if k == reflect.Slice { + extLen := v.Len() + len(vec) + extSlice := reflect.MakeSlice(v.Type(), extLen, extLen) + reflect.Copy(extSlice, v) + offset := v.Len() + for i := 0; i < len(vec); i++ { + extSlice.Index(i + offset).SetFloat(float64(vec[i])) + } + v.Set(extSlice) + } else if k == reflect.Array { + if len(vec) != v.Len() { + panic(fmt.Errorf("can not put float vector of size '%d' into array of size '%d", len(vec), v.Len())) + } + for i := 0; i < len(vec); i++ { + v.Index(i).SetFloat(float64(vec[i])) + } + } else { + panic(fmt.Errorf("can not put float vector value into not array field")) + } default: panic(fmt.Errorf("unknown key value type '%d'", pl.t.Fields[field].Type)) } @@ -265,6 +310,8 @@ func (pl *payloadIface) getValue(field int, idx int, v reflect.Value) { func (pl *payloadIface) getArray(field int, startIdx int, cnt int, v reflect.Value) { if cnt == 0 { + slice := reflect.MakeSlice(v.Type(), 0, 0) + v.Set(slice) return } @@ -568,6 +615,71 @@ func (pl *payloadIface) getArray(field int, startIdx int, cnt int, v reflect.Val v.Set(slice) } } + case valueFloat: + pi := (*[1 << 27]Cfloat)(ptr)[:cnt:cnt] + if v.Kind() == reflect.Array { + if v.Len() < cnt { + panic(fmt.Errorf("can not set %d values to array of %d elements", cnt, v.Len())) + } + switch v.Type().Elem().Kind() { + case reflect.Float32: + for i = 0; i < cnt; i++ { + v.Index(i).SetFloat(float64(float32(pi[i]))) + } + case reflect.Float64: + for i = 0; i < cnt; i++ { + v.Index(i).SetFloat(float64(float32(pi[i]))) + } + default: + panic(fmt.Errorf("can not convert '[]%s' to '[]float32'", v.Type().Elem().Kind().String())) + } + } else { + switch a := v.Addr().Interface().(type) { + case *[]float64: + if len(*a) == 0 { + *a = make([]float64, cnt) + } else { + i = len(*a) + var tmp []float64 + tmp, *a = *a, make([]float64, len(*a)+cnt) + copy(*a, tmp) + } + for j := 0; j < cnt; i, j = i+1, j+1 { + (*a)[i] = float64(float32(pi[j])) + } + case *[]float32: + if len(*a) == 0 { + *a = make([]float32, cnt) + } else { + i = len(*a) + var tmp []float32 + tmp, *a = *a, make([]float32, len(*a)+cnt) + copy(*a, tmp) + } + for j := 0; j < cnt; i, j = i+1, j+1 { + (*a)[i] = float32(pi[j]) + } + default: + var slice reflect.Value + if v.Len() == 0 { + slice = reflect.MakeSlice(v.Type(), cnt, cnt) + } else { + i = v.Len() + slice = reflect.Append(v, reflect.MakeSlice(v.Type(), cnt, cnt)) + } + for j := 0; j < cnt; i, j = i+1, j+1 { + sv := slice.Index(i) + if sv.Type().Kind() == reflect.Ptr 
{ + el := reflect.New(reflect.New(sv.Type().Elem()).Elem().Type()) + el.Elem().SetFloat(float64(float32(pi[j]))) + sv.Set(el) + } else { + sv.SetFloat(float64(float32(pi[j]))) + } + } + v.Set(slice) + } + } case valueBool: pb := (*[1 << 27]Cbool)(ptr)[:cnt:cnt] switch a := v.Addr().Interface().(type) { @@ -671,6 +783,8 @@ func (pl *payloadIface) getArray(field int, startIdx int, cnt int, v reflect.Val } v.Set(slice) } + case valueFloatVector: + panic(fmt.Errorf("array of float vector is not supported")) default: panic(fmt.Errorf("got C array with elements of unknown C type %d in field '%s' for go type '%s'", pl.t.Fields[field].Type, pl.t.Fields[field].Name, v.Type().Elem().Kind().String())) } @@ -692,6 +806,8 @@ func (pl *payloadIface) getIface(field int) interface{} { return pl.getString(field, 0) case valueUuid: return pl.getUuid(field, 0) + case valueFloatVector: + return pl.getFloatVector(field, 0) } } @@ -699,35 +815,39 @@ func (pl *payloadIface) getIface(field int) interface{} { switch pl.t.Fields[field].Type { case valueInt: - a := make([]int, l, l) + a := make([]int, l) for i := 0; i < l; i++ { a[i] = pl.getInt(field, i) } return a case valueInt64: - a := make([]int64, l, l) + a := make([]int64, l) for i := 0; i < l; i++ { a[i] = pl.getInt64(field, i) } return a case valueDouble: - a := make([]float64, l, l) + a := make([]float64, l) for i := 0; i < l; i++ { a[i] = pl.getFloat64(field, i) } return a + case valueFloat: + panic(fmt.Errorf("float32 can not be indexed")) case valueString: - a := make([]string, l, l) + a := make([]string, l) for i := 0; i < l; i++ { a[i] = pl.getString(field, i) } return a case valueUuid: - a := make([]string, l, l) + a := make([]string, l) for i := 0; i < l; i++ { a[i] = pl.getUuid(field, i) } return a + case valueFloatVector: + panic(fmt.Errorf("array of float vector is not supported")) } return nil diff --git a/cjson/ctag.go b/cjson/ctag.go index b10db744b..5808dda17 100644 --- a/cjson/ctag.go +++ b/cjson/ctag.go @@ -25,6 +25,7 @@ const ( TAG_OBJECT = 6 TAG_END = 7 TAG_UUID = 8 + TAG_FLOAT = 9 ) func (c ctag) Name() int16 { @@ -63,6 +64,8 @@ func tagTypeName(tagType int16) string { return "" case TAG_UUID: return "" + case TAG_FLOAT: + return "" default: return fmt.Sprintf("", tagType) } diff --git a/cjson/decoder.go b/cjson/decoder.go index 16245c9bc..0f55495f1 100644 --- a/cjson/decoder.go +++ b/cjson/decoder.go @@ -9,7 +9,7 @@ import ( "time" "unsafe" - "github.com/restream/reindexer/v4/bindings" + "github.com/restream/reindexer/v5/bindings" ) var ( @@ -119,6 +119,8 @@ func skipTag(rdser *Serializer, tagType int16) { rdser.GetVString() case TAG_UUID: rdser.GetUuid() + case TAG_FLOAT: + rdser.GetFloat32() default: panic(fmt.Errorf("can not skip tagType '%s'", tagTypeName(tagType))) } @@ -132,6 +134,8 @@ func asInt(rdser *Serializer, tagType int16) int64 { return rdser.GetVarInt() case TAG_DOUBLE: return int64(rdser.GetDouble()) + case TAG_FLOAT: + return int64(rdser.GetFloat32()) default: panic(fmt.Errorf("can not convert tagType '%s' to 'int'", tagTypeName(tagType))) } @@ -143,6 +147,8 @@ func asFloat(rdser *Serializer, tagType int16) float64 { return float64(rdser.GetVarInt()) case TAG_DOUBLE: return rdser.GetDouble() + case TAG_FLOAT: + return float64(rdser.GetFloat32()) default: panic(fmt.Errorf("can not convert tagType '%s' to 'float'", tagTypeName(tagType))) } @@ -181,6 +187,8 @@ func asIface(rdser *Serializer, tagType int16) interface{} { return nil case TAG_UUID: return rdser.GetUuid() + case TAG_FLOAT: + return 
rdser.GetFloat32() default: panic(fmt.Errorf("can not convert tagType '%s' to 'interface'", tagTypeName(tagType))) } @@ -297,6 +305,8 @@ func mkValue(ctagType int16) (v reflect.Value) { v = reflect.New(reflect.TypeOf(make(map[string]interface{}))).Elem() case TAG_ARRAY: v = reflect.New(ifaceSliceType).Elem() + case TAG_FLOAT: + v = reflect.New(reflect.TypeOf(float32(0.0))).Elem() default: panic(fmt.Errorf("invalid ctagType=%d", ctagType)) } @@ -662,17 +672,23 @@ func (dec *Decoder) decodeValue(pl *payloadIface, rdser *Serializer, v reflect.V if ctagField >= 0 { // get data from payload object cnt := &fieldsoutcnt[ctagField] + field := int(ctagField) switch ctagType { case TAG_ARRAY: count := int(rdser.GetVarUInt()) if k == reflect.Slice || k == reflect.Array || count != 0 { // Allows empty slice for any scalar type (using default value) - pl.getArray(int(ctagField), *cnt, count, v) - *cnt += count + if pl.t.Fields[field].Type == bindings.ValueFloatVector { + pl.getValue(field, *cnt, v) + (*cnt)++ + } else { + pl.getArray(field, *cnt, count, v) + *cnt += count + } } else { initialV.Set(reflect.Zero(initialV.Type())) // Set nil to scalar pointers, intialized with empty arrays } default: - pl.getValue(int(ctagField), *cnt, v) + pl.getValue(field, *cnt, v) (*cnt)++ } } else { diff --git a/cjson/encoder.go b/cjson/encoder.go index c491953dd..6a8fccf6d 100644 --- a/cjson/encoder.go +++ b/cjson/encoder.go @@ -609,7 +609,9 @@ func (enc *Encoder) encodeSlice(v reflect.Value, rdser *Serializer, f fieldInfo, case reflect.Int, reflect.Int16, reflect.Int64, reflect.Int8, reflect.Int32, reflect.Uint, reflect.Uint16, reflect.Uint64, reflect.Uint32: subTag = TAG_VARINT - case reflect.Float32, reflect.Float64: + case reflect.Float32: + subTag = TAG_FLOAT + case reflect.Float64: subTag = TAG_DOUBLE case reflect.String: if f.isUuid { @@ -675,7 +677,7 @@ func (enc *Encoder) encodeSlice(v reflect.Value, rdser *Serializer, f fieldInfo, case reflect.Float32: sl := (*[1 << 28]float32)(ptr)[:l:l] for _, v := range sl { - rdser.PutDouble(float64(v)) + rdser.PutFloat32(v) } case reflect.Float64: sl := (*[1 << 27]float64)(ptr)[:l:l] @@ -732,7 +734,11 @@ func (enc *Encoder) encodeSlice(v reflect.Value, rdser *Serializer, f fieldInfo, for i := 0; i < l; i++ { rdser.PutVarUInt(v.Index(i).Uint()) } - case reflect.Float32, reflect.Float64: + case reflect.Float32: + for i := 0; i < l; i++ { + rdser.PutFloat32(float32(v.Index(i).Float())) + } + case reflect.Float64: for i := 0; i < l; i++ { rdser.PutDouble(v.Index(i).Float()) } @@ -804,7 +810,16 @@ func (enc *Encoder) encodeValue(v reflect.Value, rdser *Serializer, f fieldInfo, rdser.PutCTag(mkctag(TAG_VARINT, f.ctagName, 0)) rdser.PutVarInt(int64(val)) } - case reflect.Float32, reflect.Float64: + case reflect.Float32: + val := v.Float() + if val != 0 || !f.isOmitEmpty { + // rdser.PutCTag(mkctag(TAG_FLOAT, f.ctagName, 0)) + // rdser.PutFloat32(float32(val)) + // FIXME: Encoding float64 for the some kind of compatibility. 
We should encode float32 here after full migration to v5 + rdser.PutCTag(mkctag(TAG_DOUBLE, f.ctagName, 0)) + rdser.PutDouble(val) + } + case reflect.Float64: val := v.Float() if val != 0 || !f.isOmitEmpty { rdser.PutCTag(mkctag(TAG_DOUBLE, f.ctagName, 0)) diff --git a/cjson/serializer.go b/cjson/serializer.go index a3b6577bb..398e116d1 100644 --- a/cjson/serializer.go +++ b/cjson/serializer.go @@ -94,6 +94,19 @@ func (s *Serializer) PutUuid(v [2]uint64) *Serializer { return s } +func (s *Serializer) PutFloatVector(vec []float32) *Serializer { + s.PutVarUInt(uint64(len(vec)) << 1) + for _, value := range vec { + s.writeIntBits(int64(math.Float32bits(value)), unsafe.Sizeof(value)) + } + return s +} + +func (s *Serializer) PutFloat32(v float32) *Serializer { + s.writeIntBits(int64(math.Float32bits(v)), unsafe.Sizeof(v)) + return s +} + func (s *Serializer) PutDouble(v float64) *Serializer { s.writeIntBits(int64(math.Float64bits(v)), unsafe.Sizeof(v)) return s @@ -246,6 +259,10 @@ func (s *Serializer) GetDouble() (v float64) { return math.Float64frombits(uint64(s.readIntBits(unsafe.Sizeof(v)))) } +func (s *Serializer) GetFloat32() (v float32) { + return math.Float32frombits(uint32(s.readIntBits(unsafe.Sizeof(v)))) +} + func (s *Serializer) GetBytes() (v []byte) { l := int(s.GetUInt32()) if s.pos+l > len(s.buf) { diff --git a/clang-tidy/.clang-tidy-ignore b/clang-tidy/.clang-tidy-ignore index e21099090..a085cefb2 100644 --- a/clang-tidy/.clang-tidy-ignore +++ b/clang-tidy/.clang-tidy-ignore @@ -8,3 +8,5 @@ */murmurhash/MurmurHash3.cc # Backtrace library */libbacktrace/* +# Faiss (TODO: probably should remove this line before merge) +*/faiss/* diff --git a/cpp_src/CMakeLists.txt b/cpp_src/CMakeLists.txt index f7e501382..cdca6774e 100644 --- a/cpp_src/CMakeLists.txt +++ b/cpp_src/CMakeLists.txt @@ -1,13 +1,7 @@ -cmake_minimum_required(VERSION 3.10) +cmake_minimum_required(VERSION 3.18) # Configure cmake options -if(MSVC) - # Enable C++20 for windows build to be able to use designated initializers. - # GCC/Clang support them even with C++17. - set(CMAKE_CXX_STANDARD 20) -else() - set(CMAKE_CXX_STANDARD 17) -endif() +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(CMakeToolsHelpers OPTIONAL) include(ExternalProject) @@ -33,6 +27,9 @@ option(ENABLE_GRPC "Enable GRPC service" OFF) option(ENABLE_SSE "Enable SSE instructions" ON) option(ENABLE_SERVER_AS_PROCESS_IN_TEST "Run reindexer servers as separate processes in tests" OFF) option(ENABLE_V3_FOLLOWERS "Enable compatibility mode with reindexer v3 followers. This is temporary flag and will be removed in further releases" OFF) +option(ENABLE_PCH "Enable precompiled headers for the build" OFF) +set(BUILD_ANN_INDEXES "all" CACHE STRING "Enable ANN indexes build: none, builtin, all. 'All' builds both builtin and FAISS-based indexes") +set_property(CACHE BUILD_ANN_INDEXES PROPERTY STRINGS none builtin all) if(APPLE) option(ENABLE_OPENSSL "Enable OpenSSL" OFF) @@ -40,6 +37,19 @@ else() option(ENABLE_OPENSSL "Enable OpenSSL" ON) endif() +if(APPLE AND NOT DEFINED ENV{OpenMP_ROOT}) + execute_process(COMMAND brew --prefix OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE BREW_PRFIX) + message("OpenMP_ROOT is not defined. 
Setting OMP root to: ${BREW_PRFIX}/opt/libomp") + set(ENV{OpenMP_ROOT} "${BREW_PRFIX}/opt/libomp") +endif() + +if(ENABLE_PCH) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang") + message("Disabling PCH for Clang - there are some compatibility problems") + set(ENABLE_PCH OFF CACHE INTERNAL "" FORCE) + endif() +endif() + if(NOT GRPC_PACKAGE_PROVIDER) set(GRPC_PACKAGE_PROVIDER "CONFIG") endif() @@ -72,8 +82,8 @@ if(MSVC) set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O2 -Zi") set(CMAKE_CXX_FLAGS_RELEASE "-O2 -DNDEBUG -Zi") set(CMAKE_C_FLAGS_RELEASE "-O2 -DNDEBUG -Zi") -elseif(WITH_ASAN) - # Using O2 instead of O3 to build a bit faster. +elseif(WITH_ASAN OR WITH_STDLIB_DEBUG) + # Using O2 instead of O3 to build a bit faster # Also this allows to avoid SEGFAULT in libasan.so during coroutines interaction on CentOS7 (gcc-12). set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g1") set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O2 -g1") @@ -103,8 +113,8 @@ if(MSVC) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -wd4244 -wd4267 -wd4996 -wd4717 -wd4800 -wd4396 -wd4503 -MP -MD /bigobj") set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -SAFESEH:NO") else() - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra -Werror -Wswitch-enum") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -Werror -Wswitch-enum -Wold-style-cast -fexceptions") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Wextra -Werror -Wswitch-enum -Winvalid-pch") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++20 -Wall -Wextra -Werror -Wswitch-enum -Wold-style-cast -Winvalid-pch -fexceptions") if(${COMPILER_TARGET_ARCH} STREQUAL "e2k") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -gline -fverbose-asm") set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-parameter") @@ -156,8 +166,8 @@ set(REINDEXER_SOURCE_PATH ${PROJECT_SOURCE_DIR}) set(REINDEXER_BINARY_PATH ${PROJECT_BINARY_DIR}) file ( - GLOB_RECURSE - SRCS + GLOB_RECURSE + SRCS ${REINDEXER_SOURCE_PATH}/client/* ${REINDEXER_SOURCE_PATH}/core/* ${REINDEXER_SOURCE_PATH}/estl/* @@ -270,6 +280,18 @@ if(ENABLE_SSE) message("Building with SSE support...") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse -msse2 -msse3 -mssse3 -msse4 -msse4.1 -msse4.2 -mpopcnt") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse -msse2 -msse3 -mssse3 -msse4 -msse4.1 -msse4.2 -mpopcnt") + elseif(MSVC AND (${COMPILER_TARGET_ARCH} STREQUAL "x86_64" OR ${COMPILER_TARGET_ARCH} STREQUAL "i386")) + add_definitions(-DREINDEXER_WITH_SSE=1) + # MSVC does not define __SSE__ by itself + # https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros + add_definitions(-D__SSE__=1) + add_definitions(-D__SSE2__=1) + add_definitions(-D__SSE3__=1) + add_definitions(-D__SSE4_1__=1) + add_definitions(-D__SSE4_2__=1) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:SSE4.2") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:SSE4.2") + message("Building with SSE support...") else() message("SSE compiler flags were disabled for the current platform") endif() @@ -277,6 +299,7 @@ endif() include_directories(${REINDEXER_SOURCE_PATH}) include_directories(${REINDEXER_SOURCE_PATH}/vendor) +include_directories(${REINDEXER_SOURCE_PATH}/vendor_subdirs) set(MSGPACK_INCLUDE_PATH ${REINDEXER_SOURCE_PATH}/vendor/msgpack) include_directories(${MSGPACK_INCLUDE_PATH}) @@ -296,6 +319,7 @@ endif() add_definitions(-DBUILDING_KOISHI) add_definitions(-DYAML_CPP_STATIC_DEFINE) include_directories(${KOISHI_PATH}/include) +add_definitions(-DFINTEGER=int) list(APPEND SRCS ${KOISHI_PATH}/include/koishi.h ${KOISHI_PATH}/fiber.h @@ -312,11 +336,71 
@@ endif() if(MSVC) set_source_files_properties (${REINDEXER_SOURCE_PATH}/core/storage/leveldblogger.cc PROPERTIES COMPILE_FLAGS "/GR-") else() - set_source_files_properties (${REINDEXER_SOURCE_PATH}/core/storage/leveldblogger.cc PROPERTIES COMPILE_FLAGS "-fno-rtti") + set_source_files_properties (${REINDEXER_SOURCE_PATH}/core/storage/leveldblogger.cc PROPERTIES COMPILE_FLAGS "-fno-rtti -Wno-invalid-pch") endif() list(APPEND REINDEXER_LIBRARIES reindexer) -add_library(${TARGET} STATIC ${HDRS} ${SRCS} ${VENDORS}) +add_library(${TARGET}_obj OBJECT ${SRCS}) + +if(ENABLE_PCH) + set(PREC_HDRS + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + $<$:> + + $<$:estl/h_vector.h> + $<$:estl/fast_hash_set.h> + $<$:estl/fast_hash_map.h> + $<$:estl/intrusive_ptr.h> + $<$:estl/mutex.h> + $<$:estl/one_of.h> + $<$:estl/shared_mutex.h> + $<$:estl/smart_lock.h> + $<$:estl/tokenizer.h> + $<$:estl/contexted_locks.h> + $<$:estl/cow.h> + $<$:tools/errors.h> + $<$:tools/logger.h> + $<$:tools/lsn.h> + $<$:tools/serializer.h> + $<$:tools/stringstools.h> + $<$:tools/varint.h> + $<$:vendor/sort/pdqsort.hpp> + + $<$:updates/updaterecord.h> + $<$:updates/updatesqueue.h> + + $<$:cluster/config.h> + + $<$:core/expressiontree.h> + $<$:core/keyvalue/variant.h> + $<$:core/keyvalue/p_string.h> + $<$:core/keyvalue/key_string.h> + $<$:core/payload/payloadiface.h> + $<$:core/rdxcontext.h> + $<$:core/ft/usingcontainer.h> + $<$:vendor/sparse-map/sparse_map.h> + $<$:vendor/sparse-map/sparse_set.h> + + $<$:net/ev/ev.h> + ) + target_precompile_headers(${TARGET}_obj PRIVATE ${PREC_HDRS}) +endif() + add_definitions(-DREINDEX_CORE_BUILD=1) add_definitions(-DFMT_HEADER_ONLY=1) add_definitions(-DSPDLOG_FMT_EXTERNAL=1) @@ -381,6 +465,32 @@ else() list(APPEND REINDEXER_LIBRARIES snappy) endif() +set(OBJ_LIBRARIES ${TARGET}_obj) +if(BUILD_ANN_INDEXES STREQUAL "all") + message("Building with full ANN-indexes support...") + find_package(OpenMP REQUIRED) + include_directories(SYSTEM ${OpenMP_CXX_INCLUDE_DIR}) + list(APPEND REINDEXER_LIBRARIES ${OpenMP_CXX_LIBRARIES}) + if(NOT APPLE AND NOT MSVC) + set(EXTRA_FLAGS "${EXTRA_FLAGS} -fopenmp") + list(APPEND REINDEXER_LIBRARIES ${OpenMP_CXX_LIBRARIES} gomp) + endif() + + set(BUILD_FAISS_GLOBAL ON) + add_definitions(-DRX_WITH_FAISS_ANN_INDEXES=1) + add_definitions(-DRX_WITH_BUILTIN_ANN_INDEXES=1) + add_subdirectory(vendor_subdirs/faiss) + list(APPEND OBJ_LIBRARIES faiss_obj) +elseif(BUILD_ANN_INDEXES STREQUAL "builtin") + message("Building with builtin ANN-indexes only...") + add_definitions(-DRX_WITH_BUILTIN_ANN_INDEXES=1) +else() + message("Building without ANN-indexes...") +endif() + +add_library(${TARGET}) +target_link_libraries(${TARGET} PUBLIC ${OBJ_LIBRARIES}) + # storage ######### # rocksdb @@ -473,7 +583,7 @@ if(NOT LevelDB_LIBRARY OR NOT LevelDB_INCLUDE_DIR OR WITH_TSAN) link_directories(${CMAKE_CURRENT_BINARY_DIR}) list(APPEND REINDEXER_LINK_DIRECTORIES ${CMAKE_CURRENT_BINARY_DIR}) list(INSERT REINDEXER_LIBRARIES 1 leveldb) - add_dependencies(reindexer leveldb_lib) + add_dependencies(${TARGET}_obj leveldb_lib) else() message(STATUS "Found LevelDB: ${LevelDB_LIBRARY}") include_directories(SYSTEM ${LevelDB_INCLUDE_DIR}) @@ -498,7 +608,7 @@ if(WITH_CPPTRACE) cpptrace_lib GIT_REPOSITORY "https://github.com/jeremy-rifkin/cpptrace.git" GIT_TAG "v0.3.1" - CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} + CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} 
-DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR} -DCPPTRACE_BUILD_SHARED=Off -DCPPTRACE_GET_SYMBOLS_WITH_DBGHELP=On @@ -506,7 +616,7 @@ if(WITH_CPPTRACE) -DCPPTRACE_DEMANGLE_WITH_WINAPI=On ) add_definitions(-DREINDEX_WITH_CPPTRACE) - add_dependencies(reindexer cpptrace_lib) + add_dependencies(${TARGET}_obj cpptrace_lib) list(APPEND REINDEXER_LIBRARIES cpptrace ${REINDEXER_LIBRARIES}) endif() @@ -518,11 +628,9 @@ endif() if(NOT WIN32) # libdl - find_library(LIBDL dl) - if(LIBDL) - list(APPEND REINDEXER_LIBRARIES ${LIBDL}) - add_definitions(-DREINDEX_WITH_LIBDL=1) - endif() + find_library(LIBDL dl REQUIRED) + list(APPEND REINDEXER_LIBRARIES ${LIBDL}) + add_definitions(-DREINDEX_WITH_LIBDL=1) endif() # Unwind from libgcc or clang @@ -534,9 +642,9 @@ list(REMOVE_ITEM CMAKE_REQUIRED_DEFINITIONS -D_GNU_SOURCE) if(HAVE_BACKTRACE AND HAVE_GETIPINFO) set(SYSUNWIND On) message("-- Found system unwind") - add_definitions(-DREINDEX_WITH_UNWIND=1) + add_definitions(-DREINDEX_WITH_UNWIND=1) endif() - + # libunwind if(ENABLE_LIBUNWIND) find_library(LIBUNWIND unwind) diff --git a/cpp_src/client/coroqueryresults.cc b/cpp_src/client/coroqueryresults.cc index b1273e35d..6af03df6e 100644 --- a/cpp_src/client/coroqueryresults.cc +++ b/cpp_src/client/coroqueryresults.cc @@ -75,8 +75,8 @@ void CoroQueryResults::Bind(std::string_view rawResult, RPCQrId id, const Query* ser.GetRawQueryParams( i_.queryParams_, [&ser, this](int nsIdx) { - const uint32_t stateToken = ser.GetVarUint(); - const int version = ser.GetVarUint(); + const uint32_t stateToken = ser.GetVarUInt(); + const int version = ser.GetVarUInt(); TagsMatcher newTm; newTm.deserialize(ser, version, stateToken); i_.nsArray_[nsIdx]->TryReplaceTagsMatcher(std::move(newTm)); @@ -276,7 +276,7 @@ void CoroQueryResults::Iterator::getJSONFromCJSON(std::string_view cjson, WrSeri } if (qr_->HaveJoined() && joinedData_.size()) { EncoderDatasourceWithJoins joinsDs(joinedData_, *qr_); - AdditionalDatasource ds = qr_->NeedOutputRank() ? AdditionalDatasource(itemParams_.proc, &joinsDs) : AdditionalDatasource(&joinsDs); + AdditionalDatasource ds = qr_->NeedOutputRank() ? AdditionalDatasource(itemParams_.rank, &joinsDs) : AdditionalDatasource(&joinsDs); dss.push_back(&ds); if (withHdrLen) { auto slicePosSaver = wrser.StartSlice(); @@ -285,7 +285,7 @@ void CoroQueryResults::Iterator::getJSONFromCJSON(std::string_view cjson, WrSeri return; } - AdditionalDatasource ds(itemParams_.proc, nullptr); + AdditionalDatasource ds(itemParams_.rank, nullptr); AdditionalDatasource* dspPtr = qr_->NeedOutputRank() ? 
&ds : nullptr; if (dspPtr) { dss.push_back(dspPtr); @@ -487,11 +487,13 @@ int CoroQueryResults::Iterator::GetShardID() { return ShardingKeyType::ProxyOff; } -int16_t CoroQueryResults::Iterator::GetRank() { +RankT CoroQueryResults::Iterator::GetRank() { readNext(); - return itemParams_.proc; + return itemParams_.rank; } +bool CoroQueryResults::Iterator::IsRanked() { return qr_->HaveRank(); } + bool CoroQueryResults::Iterator::IsRaw() { readNext(); return itemParams_.raw; @@ -529,9 +531,9 @@ void CoroQueryResults::Iterator::readNext() { int format = qr_->i_.queryParams_.flags & kResultsFormatMask; (void)format; assert(format == kResultsCJson); - int joinedFields = ser.GetVarUint(); + int joinedFields = ser.GetVarUInt(); for (int i = 0; i < joinedFields; ++i) { - int itemsCount = ser.GetVarUint(); + int itemsCount = ser.GetVarUInt(); h_vector joined; joined.reserve(itemsCount); for (int j = 0; j < itemsCount; ++j) { diff --git a/cpp_src/client/coroqueryresults.h b/cpp_src/client/coroqueryresults.h index 4d363fa0e..c69e68b39 100644 --- a/cpp_src/client/coroqueryresults.h +++ b/cpp_src/client/coroqueryresults.h @@ -5,6 +5,7 @@ #include "client/item.h" #include "client/resultserializer.h" #include "core/namespace/incarnationtags.h" +#include "core/rank_t.h" #include "tools/clock.h" #include "tools/lsn.h" @@ -68,7 +69,8 @@ class CoroQueryResults { int GetNSID(); int GetID(); int GetShardID(); - int16_t GetRank(); + RankT GetRank(); + bool IsRanked(); bool IsRaw(); std::string_view GetRaw(); const JoinedData& GetJoined(); diff --git a/cpp_src/client/cororeindexer.h b/cpp_src/client/cororeindexer.h index 609113cc8..14589fcdc 100644 --- a/cpp_src/client/cororeindexer.h +++ b/cpp_src/client/cororeindexer.h @@ -1,5 +1,6 @@ #pragma once +#include "client/connectopts.h" #include "client/coroqueryresults.h" #include "client/corotransaction.h" #include "client/internalrdxcontext.h" diff --git a/cpp_src/client/corotransaction.cc b/cpp_src/client/corotransaction.cc index a71891b71..feb1e38bd 100644 --- a/cpp_src/client/corotransaction.cc +++ b/cpp_src/client/corotransaction.cc @@ -66,7 +66,7 @@ Error CoroTransaction::Modify(Query&& query, lsn_t lsn) { case QuerySelect: case QueryTruncate: default: - return Error(errParams, "Incorrect query type in transaction modify %d", query.type_); + return Error(errParams, "Incorrect query type in transaction modify %d", int(query.type_)); } } diff --git a/cpp_src/client/itemimpl.h b/cpp_src/client/itemimpl.h index f12a28f56..e63bcde62 100644 --- a/cpp_src/client/itemimpl.h +++ b/cpp_src/client/itemimpl.h @@ -1,4 +1,5 @@ #pragma once + #include "client/itemimplbase.h" namespace reindexer { diff --git a/cpp_src/client/itemimplbase.cc b/cpp_src/client/itemimplbase.cc index d22b2e325..71e2f3c21 100644 --- a/cpp_src/client/itemimplbase.cc +++ b/cpp_src/client/itemimplbase.cc @@ -1,9 +1,10 @@ -#include "itemimplbase.h" +#include "client/itemimplbase.h" #include "core/cjson/baseencoder.h" #include "core/cjson/cjsondecoder.h" #include "core/cjson/jsondecoder.h" #include "core/cjson/msgpackbuilder.h" #include "core/cjson/msgpackdecoder.h" +#include "estl/gift_str.h" namespace reindexer { namespace client { @@ -34,7 +35,7 @@ void ItemImplBase::FromCJSON(std::string_view slice) { CJsonDecoder decoder(tagsMatcher_, holder_); ser_.Reset(); try { - decoder.Decode(pl, rdser, ser_); + decoder.Decode(pl, rdser, ser_, floatVectorsHolder_); } catch (const Error& e) { if (!hasBundledTm) { const auto err = tryToUpdateTagsMatcher(); @@ -44,7 +45,7 @@ void 
ItemImplBase::FromCJSON(std::string_view slice) { ser_.Reset(); rdser.SetPos(0); CJsonDecoder decoder(tagsMatcher_, holder_); - decoder.Decode(pl, rdser, ser_); + decoder.Decode(pl, rdser, ser_, floatVectorsHolder_); } } @@ -55,7 +56,7 @@ void ItemImplBase::FromCJSON(std::string_view slice) { const auto tupleSize = ser_.Len(); tupleHolder_ = ser_.DetachBuf(); tupleData_ = std::string_view(reinterpret_cast(tupleHolder_.get()), tupleSize); - pl.Set(0, Variant(p_string(&tupleData_), Variant::no_hold_t{})); + pl.Set(0, Variant{p_string(&tupleData_), Variant::noHold}); } Error ItemImplBase::FromJSON(std::string_view slice, char** endp, bool /*pkOnly*/) { @@ -72,7 +73,7 @@ Error ItemImplBase::FromJSON(std::string_view slice, char** endp, bool /*pkOnly* gason::JsonParser parser(&largeJSONStrings_); try { node = parser.Parse(giftStr(data), &len); - if (node.value.getTag() != gason::JSON_OBJECT) { + if (node.value.getTag() != gason::JsonTag::OBJECT) { return Error(errParseJson, "Expected json object"); } if (unsafe_ && endp) { @@ -86,14 +87,14 @@ Error ItemImplBase::FromJSON(std::string_view slice, char** endp, bool /*pkOnly* JsonDecoder decoder(tagsMatcher_); Payload pl = GetPayload(); ser_.Reset(); - auto err = decoder.Decode(pl, ser_, node.value); + auto err = decoder.Decode(pl, ser_, node.value, floatVectorsHolder_); if (err.ok()) { // Put tuple to field[0] const auto tupleSize = ser_.Len(); tupleHolder_ = ser_.DetachBuf(); tupleData_ = std::string_view(reinterpret_cast(tupleHolder_.get()), tupleSize); - pl.Set(0, Variant(p_string(&tupleData_), Variant::no_hold_t{})); + pl.Set(0, Variant(p_string(&tupleData_), Variant::noHold)); } return err; } @@ -109,12 +110,12 @@ Error ItemImplBase::FromMsgPack(std::string_view buf, size_t& offset) { } ser_.Reset(); - Error err = decoder.Decode(data, pl, ser_, offset); + Error err = decoder.Decode(data, pl, ser_, offset, floatVectorsHolder_); if (err.ok()) { const auto tupleSize = ser_.Len(); tupleHolder_ = ser_.DetachBuf(); tupleData_ = std::string_view(reinterpret_cast(tupleHolder_.get()), tupleSize); - pl.Set(0, Variant(p_string(&tupleData_), Variant::no_hold_t{})); + pl.Set(0, Variant(p_string(&tupleData_), Variant::noHold)); } return err; } diff --git a/cpp_src/client/itemimplbase.h b/cpp_src/client/itemimplbase.h index ac3d6acfc..44e1ae269 100644 --- a/cpp_src/client/itemimplbase.h +++ b/cpp_src/client/itemimplbase.h @@ -1,6 +1,7 @@ #pragma once #include +#include "core/keyvalue/float_vectors_holder.h" #include "core/keyvalue/variant.h" #include "core/payload/payloadiface.h" #include "core/query/query.h" @@ -62,13 +63,14 @@ class ItemImplBase { static bool ReadBundledTmTag(Serializer& ser) { return ser.GetCTag() == kCTagEnd; } protected: - virtual Error tryToUpdateTagsMatcher() = 0; - // Index fields payload data PayloadType payloadType_; PayloadValue payloadValue_; TagsMatcher tagsMatcher_; +private: + virtual Error tryToUpdateTagsMatcher() = 0; + WrSerializer ser_; std::string_view tupleData_; std::unique_ptr tupleHolder_; @@ -77,6 +79,7 @@ class ItemImplBase { bool unsafe_ = false; h_vector holder_; std::vector> largeJSONStrings_; + FloatVectorsHolderVector floatVectorsHolder_; }; } // namespace client diff --git a/cpp_src/client/raftclient.h b/cpp_src/client/raftclient.h index d5c4e7e2f..d555536fa 100644 --- a/cpp_src/client/raftclient.h +++ b/cpp_src/client/raftclient.h @@ -1,5 +1,6 @@ #pragma once +#include "client/connectopts.h" #include "client/internalrdxcontext.h" #include "client/reindexerconfig.h" #include "client/rpcclient.h" diff --git 
a/cpp_src/client/reindexer.h b/cpp_src/client/reindexer.h index 383beb947..e370ee78d 100644 --- a/cpp_src/client/reindexer.h +++ b/cpp_src/client/reindexer.h @@ -1,5 +1,6 @@ #pragma once +#include "client/connectopts.h" #include "client/queryresults.h" #include "client/reindexerconfig.h" #include "client/transaction.h" diff --git a/cpp_src/client/reindexerconfig.h b/cpp_src/client/reindexerconfig.h index f798b59f1..a6ab901a8 100644 --- a/cpp_src/client/reindexerconfig.h +++ b/cpp_src/client/reindexerconfig.h @@ -2,7 +2,6 @@ #include #include -#include "connectopts.h" namespace reindexer { namespace client { diff --git a/cpp_src/client/reindexerimpl.cc b/cpp_src/client/reindexerimpl.cc index c60872e2c..078d03acb 100644 --- a/cpp_src/client/reindexerimpl.cc +++ b/cpp_src/client/reindexerimpl.cc @@ -707,7 +707,7 @@ void ReindexerImpl::coroInterpreter(Connection& conn, Connectio break; case QuerySelect: case QueryTruncate: - err = Error(errParams, "Incorrect query type in transaction modify %d", std::get<1>(cd->arguments).type_); + err = Error(errParams, "Incorrect query type in transaction modify %d", int(std::get<1>(cd->arguments).type_)); } } if (cd->ctx.cmpl()) { diff --git a/cpp_src/client/reindexerimpl.h b/cpp_src/client/reindexerimpl.h index 92f0e4e15..f51273a93 100644 --- a/cpp_src/client/reindexerimpl.h +++ b/cpp_src/client/reindexerimpl.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include #include diff --git a/cpp_src/client/resultserializer.cc b/cpp_src/client/resultserializer.cc index 883b6a8b3..effe1e98e 100644 --- a/cpp_src/client/resultserializer.cc +++ b/cpp_src/client/resultserializer.cc @@ -1,15 +1,16 @@ #include "resultserializer.h" #include "core/payload/payloadtypeimpl.h" +#include "estl/gift_str.h" namespace reindexer { namespace client { void ResultSerializer::GetRawQueryParams(ResultSerializer::QueryParams& ret, const std::function& updatePayloadFunc, Options opts, ParsingData& parsingData) { - ret.flags = GetVarUint(); - ret.totalcount = GetVarUint(); - ret.qcount = GetVarUint(); - ret.count = GetVarUint(); + ret.flags = GetVarUInt(); + ret.totalcount = GetVarUInt(); + ret.qcount = GetVarUInt(); + ret.count = GetVarUInt(); ret.nsIncarnationTags.clear(); ret.shardingConfigVersion = ShardingSourceId::NotSet; if (opts.IsWithClearAggs()) { @@ -25,10 +26,10 @@ void ResultSerializer::GetRawQueryParams(ResultSerializer::QueryParams& ret, con parsingData.pts.begin = Pos(); if (ret.flags & kResultsWithPayloadTypes) { - int ptCount = GetVarUint(); + int ptCount = GetVarUInt(); for (int i = 0; i < ptCount; ++i) { - int nsid = GetVarUint(); + int nsid = GetVarUInt(); GetVString(); assertrx(updatePayloadFunc != nullptr); @@ -52,7 +53,7 @@ void ResultSerializer::GetExtraParams(ResultSerializer::QueryParams& ret, Option } bool firstLazyData = true; for (;;) { - const int tag = GetVarUint(); + const int tag = GetVarUInt(); switch (tag) { case QueryResultEnd: return; @@ -64,16 +65,12 @@ void ResultSerializer::GetExtraParams(ResultSerializer::QueryParams& ret, Option ret.aggResults.emplace(); ret.explainResults.emplace(); } - // firstLazyData guaranties, that aggResults will be non-'nullopt' - auto& aggRes = ret.aggResults->emplace_back(); // NOLINT(bugprone-unchecked-optional-access) - Error err; - if ((ret.flags & kResultsFormatMask) == kResultsMsgPack) { - err = aggRes.FromMsgPack(data); + if (auto aggRes = ((ret.flags & kResultsFormatMask) == kResultsMsgPack) ? 
AggregationResult::FromMsgPack(data) + : AggregationResult::FromJSON(giftStr(data))) { + // firstLazyData guarantees that aggResults will be non-'nullopt' + ret.aggResults->emplace_back(std::move(*aggRes)); // NOLINT(bugprone-unchecked-optional-access) } else { - err = aggRes.FromJSON(giftStr(data)); - } - if (!err.ok()) { - throw err; + throw aggRes.error(); } } break; @@ -91,21 +88,21 @@ void ResultSerializer::GetExtraParams(ResultSerializer::QueryParams& ret, Option break; } case QueryResultShardingVersion: { - ret.shardingConfigVersion = GetVarUint(); + ret.shardingConfigVersion = GetVarUInt(); break; } case QueryResultShardId: { - ret.shardId = GetVarUint(); + ret.shardId = GetVarUInt(); break; } case QueryResultIncarnationTags: { - const auto size = GetVarUint(); + const auto size = GetVarUInt(); if (size) { ret.nsIncarnationTags.reserve(size); for (size_t i = 0; i < size; ++i) { auto& shardTags = ret.nsIncarnationTags.emplace_back(); shardTags.shardId = GetVarint(); - const auto tagsSize = GetVarUint(); + const auto tagsSize = GetVarUInt(); shardTags.tags.reserve(tagsSize); for (size_t j = 0; j < tagsSize; ++j) { shardTags.tags.emplace_back(GetVarint()); @@ -114,6 +111,13 @@ void ResultSerializer::GetExtraParams(ResultSerializer::QueryParams& ret, Option } break; } + case QueryResultRankFormat: { + const auto format = GetVarUInt(); + if (format != RankFormat::SingleFloatValue) { + throw Error(errLogic, "Unexpected rank format value tag: %d - only supported format is 0 (single float rank)", format); + } + break; + } default: throw Error(errLogic, "Unexpected Query tag: %d", tag); } @@ -124,15 +128,15 @@ ResultSerializer::ItemParams ResultSerializer::GetItemData(int flags, int shardI ItemParams ret; if (flags & kResultsWithItemID) { - ret.id = int(GetVarUint()); - ret.lsn = lsn_t(GetVarUint()); + ret.id = int(GetVarUInt()); + ret.lsn = lsn_t(GetVarUInt()); } if (flags & kResultsWithNsID) { - ret.nsid = int(GetVarUint()); + ret.nsid = int(GetVarUInt()); } if (flags & kResultsWithRank) { - ret.proc = int(GetVarUint()); + ret.rank = GetRank(); } if (flags & kResultsWithRaw) { @@ -143,7 +147,7 @@ ResultSerializer::ItemParams ResultSerializer::GetItemData(int flags, int shardI if (shardId != ShardingKeyType::ProxyOff) { ret.shardId = shardId; } else { - ret.shardId = int(GetVarUint()); + ret.shardId = int(GetVarUInt()); } } switch (flags & kResultsFormatMask) { diff --git a/cpp_src/client/resultserializer.h b/cpp_src/client/resultserializer.h index 84d6faf10..a2e34a634 100644 --- a/cpp_src/client/resultserializer.h +++ b/cpp_src/client/resultserializer.h @@ -3,6 +3,7 @@ #include #include "core/namespace/incarnationtags.h" #include "core/queryresults/aggregationresult.h" +#include "core/rank_t.h" #include "tools/lsn.h" #include "tools/serializer.h" @@ -27,9 +28,9 @@ class ResultSerializer : public Serializer { }; struct ItemParams { - int id = -1; - int16_t nsid = 0; - int16_t proc = 0; + IdType id = -1; + uint16_t nsid = 0; + RankT rank = 0.0; lsn_t lsn; std::string_view data; bool raw = false; @@ -61,7 +62,7 @@ class ResultSerializer : public Serializer { bool ContainsPayloads() const { Serializer ser(Buf(), Len()); - return ser.GetVarUint() & kResultsWithPayloadTypes; + return ser.GetVarUInt() & kResultsWithPayloadTypes; } void GetRawQueryParams(QueryParams& ret, const std::function& updatePayloadFunc, Options options, ParsingData& parsingData); diff --git a/cpp_src/client/rpcclient.cc b/cpp_src/client/rpcclient.cc index df083006e..9e23dc68a 100644 --- a/cpp_src/client/rpcclient.cc +++ 
b/cpp_src/client/rpcclient.cc @@ -1,5 +1,5 @@ #include "client/rpcclient.h" -#include +#include "client/connectopts.h" #include "client/itemimplbase.h" #include "client/snapshot.h" #include "cluster/clustercontrolrequest.h" @@ -7,6 +7,7 @@ #include "core/namespace/namespacestat.h" #include "core/namespacedef.h" #include "core/schema.h" +#include "estl/gift_str.h" #include "gason/gason.h" #include "tools/catch_and_return.h" #include "tools/cpucheck.h" @@ -15,8 +16,6 @@ namespace reindexer { namespace client { -using reindexer::net::cproto::CoroRPCAnswer; - RPCClient::RPCClient(const ReindexerConfig& config, INamespaces::PtrT sharedNamespaces) : namespaces_(sharedNamespaces ? std::move(sharedNamespaces) : INamespaces::PtrT(new NamespacesImpl())), config_(config) { reindexer::CheckRequiredSSESupport(); @@ -299,8 +298,8 @@ Error RPCClient::modifyItemRaw(std::string_view nsName, std::string_view cjson, ser.GetRawQueryParams( qdata, [&ser, nsPtr = std::move(nsPtr)](int nsIdx) { - const uint32_t stateToken = ser.GetVarUint(); - const int version = ser.GetVarUint(); + const uint32_t stateToken = ser.GetVarUInt(); + const int version = ser.GetVarUInt(); TagsMatcher newTm; newTm.deserialize(ser, version, stateToken); if (nsIdx != 0) { @@ -504,7 +503,7 @@ Error RPCClient::UpdateIndex(std::string_view nsName, const IndexDef& iDef, cons } Error RPCClient::DropIndex(std::string_view nsName, const IndexDef& idx, const InternalRdxContext& ctx) { - return conn_.Call(mkCommand(cproto::kCmdDropIndex, &ctx), nsName, idx.name_).Status(); + return conn_.Call(mkCommand(cproto::kCmdDropIndex, &ctx), nsName, idx.Name()).Status(); } Error RPCClient::SetSchema(std::string_view nsName, std::string_view schema, const InternalRdxContext& ctx) { diff --git a/cpp_src/client/rpcclient.h b/cpp_src/client/rpcclient.h index d99ff448e..b70411427 100644 --- a/cpp_src/client/rpcclient.h +++ b/cpp_src/client/rpcclient.h @@ -1,4 +1,4 @@ -#pragma once +#pragma once #include "client/coroqueryresults.h" #include "client/corotransaction.h" @@ -28,6 +28,7 @@ struct ShardingControlResponseData; } // namespace sharding namespace client { +struct ConnectOpts; class Snapshot; template diff --git a/cpp_src/cluster/clustercontrolrequest.cc b/cpp_src/cluster/clustercontrolrequest.cc index e1ab307be..c9fae58ce 100644 --- a/cpp_src/cluster/clustercontrolrequest.cc +++ b/cpp_src/cluster/clustercontrolrequest.cc @@ -12,7 +12,7 @@ void ClusterControlRequestData::GetJSON(WrSerializer& ser) const { std::visit([&payloadBuilder](const auto& d) { d.GetJSON(payloadBuilder); }, data); } } -Error ClusterControlRequestData::FromJSON(span json) { +Error ClusterControlRequestData::FromJSON(std::span json) { try { gason::JsonParser parser; auto node = parser.Parse(json); diff --git a/cpp_src/cluster/clustercontrolrequest.h b/cpp_src/cluster/clustercontrolrequest.h index d58379ca9..832ecbf47 100644 --- a/cpp_src/cluster/clustercontrolrequest.h +++ b/cpp_src/cluster/clustercontrolrequest.h @@ -1,6 +1,6 @@ #pragma once +#include #include -#include "estl/span.h" #include "tools/errors.h" #include "tools/serializer.h" @@ -24,7 +24,7 @@ struct ClusterControlRequestData { ClusterControlRequestData() = default; ClusterControlRequestData(SetClusterLeaderCommand&& value) : type(Type::ChangeLeader), data(std::move(value)) {} void GetJSON(WrSerializer& ser) const; - Error FromJSON(span json); + Error FromJSON(std::span json); Type type = Type::Empty; std::variant data; diff --git a/cpp_src/cluster/clusterizator.cc b/cpp_src/cluster/clusterizator.cc index 
3bcb4b3e7..e7b562188 100644 --- a/cpp_src/cluster/clusterizator.cc +++ b/cpp_src/cluster/clusterizator.cc @@ -144,6 +144,14 @@ Error Clusterizator::ReplicateAsync(UpdatesContainer&& recs, const RdxContext& c return {}; // This namespace is not taking part in any replication } +ReplicationStats Clusterizator::GetAsyncReplicationStats() const { return asyncReplicator_.GetReplicationStats(); } + +ReplicationStats Clusterizator::GetClusterReplicationStats() const { return clusterReplicator_.GetReplicationStats(); } + +void Clusterizator::SetAsyncReplicatonLogLevel(LogLevel level) noexcept { asyncReplicator_.SetLogLevel(level); } + +void Clusterizator::SetClusterReplicatonLogLevel(LogLevel level) noexcept { clusterReplicator_.SetLogLevel(level); } + bool Clusterizator::replicationIsNotRequired(const UpdatesContainer& recs) noexcept { return recs.empty() || isSystemNamespaceNameFast(recs[0].NsName()); } diff --git a/cpp_src/cluster/clusterizator.h b/cpp_src/cluster/clusterizator.h index 276da3737..6bf362162 100644 --- a/cpp_src/cluster/clusterizator.h +++ b/cpp_src/cluster/clusterizator.h @@ -54,10 +54,10 @@ class Clusterizator : public IDataReplicator, public IDataSyncer { bool IsInitialSyncDone() const override final { return !enabled_.load(std::memory_order_acquire) || sharedSyncState_.IsInitialSyncDone(); } - ReplicationStats GetAsyncReplicationStats() const { return asyncReplicator_.GetReplicationStats(); } - ReplicationStats GetClusterReplicationStats() const { return clusterReplicator_.GetReplicationStats(); } - void SetAsyncReplicatonLogLevel(LogLevel level) noexcept { asyncReplicator_.SetLogLevel(level); } - void SetClusterReplicatonLogLevel(LogLevel level) noexcept { clusterReplicator_.SetLogLevel(level); } + ReplicationStats GetAsyncReplicationStats() const; + ReplicationStats GetClusterReplicationStats() const; + void SetAsyncReplicatonLogLevel(LogLevel level) noexcept; + void SetClusterReplicatonLogLevel(LogLevel level) noexcept; private: static bool replicationIsNotRequired(const UpdatesContainer& recs) noexcept; @@ -65,7 +65,7 @@ class Clusterizator : public IDataReplicator, public IDataSyncer { mutable std::mutex mtx_; UpdatesQueuePair updatesQueue_; - SharedSyncState<> sharedSyncState_; + SharedSyncState sharedSyncState_; ClusterDataReplicator clusterReplicator_; AsyncDataReplicator asyncReplicator_; net::ev::async terminateAsync_; diff --git a/cpp_src/cluster/config.cc b/cpp_src/cluster/config.cc index 75465727b..9ab3e62c1 100644 --- a/cpp_src/cluster/config.cc +++ b/cpp_src/cluster/config.cc @@ -12,8 +12,7 @@ using namespace std::string_view_literals; -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { static void ValidateDSN(const DSN& dsn) { if (dsn.Parser().scheme() != "cproto" && dsn.Parser().scheme() != "cprotos") { @@ -21,7 +20,7 @@ static void ValidateDSN(const DSN& dsn) { } } -Error NodeData::FromJSON(span json) { +Error NodeData::FromJSON(std::span json) { try { gason::JsonParser parser; return FromJSON(parser.Parse(json)); @@ -57,7 +56,7 @@ void NodeData::GetJSON(WrSerializer& ser) const { GetJSON(jb); } -Error RaftInfo::FromJSON(span json) { +Error RaftInfo::FromJSON(std::span json) { try { gason::JsonParser parser; return FromJSON(parser.Parse(json)); @@ -661,11 +660,11 @@ sharding::Segment ShardingConfig::Key::SegmentFromYAML(const YAML::Node sharding::Segment ShardingConfig::Key::SegmentFromJSON(const gason::JsonNode& json) { const auto& jsonValue = json.value; switch (jsonValue.getTag()) { - case gason::JsonTag::JSON_TRUE: - case 
gason::JsonTag::JSON_FALSE: - case gason::JsonTag::JSON_STRING: - case gason::JsonTag::JSON_DOUBLE: - case gason::JsonTag::JSON_NUMBER: { + case gason::JsonTag::JTRUE: + case gason::JsonTag::JFALSE: + case gason::JsonTag::STRING: + case gason::JsonTag::DOUBLE: + case gason::JsonTag::NUMBER: { auto val = stringToVariant(stringifyJson(json, false)); if (val.Type().Is()) { @@ -674,7 +673,7 @@ sharding::Segment ShardingConfig::Key::SegmentFromJSON(const gason::Jso return sharding::Segment{val, val}; } - case gason::JsonTag::JSON_OBJECT: { + case gason::JsonTag::OBJECT: { algorithmType = ByRange; const auto& range = json["range"]; if (auto dist = std::distance(begin(range), end(range)); dist != 2) { @@ -691,9 +690,9 @@ sharding::Segment ShardingConfig::Key::SegmentFromJSON(const gason::Jso return sharding::Segment{std::move(left), std::move(right)}; } - case gason::JsonTag::JSON_ARRAY: + case gason::JsonTag::ARRAY: case gason::JsonTag::JSON_NULL: - case gason::JsonTag::JSON_EMPTY: + case gason::JsonTag::EMPTY: default: throw Error(errParams, "Incorrect JsonTag for sharding key: %d", int(jsonValue.getTag())); } @@ -907,7 +906,18 @@ Error ShardingConfig::FromYAML(const std::string& yaml) { } } -Error ShardingConfig::FromJSON(span json) { +Error ShardingConfig::FromJSON(std::string_view json) { + try { + gason::JsonParser parser; + return FromJSON(parser.Parse(json)); + } catch (const gason::Exception& ex) { + return Error(errParseJson, "ShardingConfig: %s", ex.what()); + } catch (const Error& err) { + return err; + } +} + +Error ShardingConfig::FromJSON(std::span json) { try { gason::JsonParser parser; return FromJSON(parser.Parse(json)); @@ -1099,5 +1109,4 @@ Error ShardingConfig::Validate() const { return {}; } -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/config.h b/cpp_src/cluster/config.h index eb801ae13..9a3b1fe81 100644 --- a/cpp_src/cluster/config.h +++ b/cpp_src/cluster/config.h @@ -2,9 +2,9 @@ #include #include +#include #include "core/keyvalue/variant.h" #include "core/namespace/namespacenamesets.h" -#include "estl/span.h" #include "sharding/ranges.h" #include "tools/dsn.h" #include "tools/errors.h" @@ -34,7 +34,7 @@ struct NodeData { int electionsTerm = 0; DSN dsn; - Error FromJSON(span json); + Error FromJSON(std::span json); Error FromJSON(const gason::JsonNode& v); void GetJSON(JsonBuilder& jb) const; void GetJSON(WrSerializer& ser) const; @@ -48,7 +48,7 @@ struct RaftInfo { bool operator==(const RaftInfo& rhs) const noexcept { return role == rhs.role && leaderId == rhs.leaderId; } bool operator!=(const RaftInfo& rhs) const noexcept { return !(*this == rhs); } - Error FromJSON(span json); + Error FromJSON(std::span json); Error FromJSON(const gason::JsonNode& root); void GetJSON(JsonBuilder& jb) const; void GetJSON(WrSerializer& ser) const; @@ -229,7 +229,8 @@ struct ShardingConfig { }; Error FromYAML(const std::string& yaml); - Error FromJSON(span json); + Error FromJSON(std::string_view json); + Error FromJSON(std::span json); Error FromJSON(const gason::JsonNode&); std::string GetYAML() const; YAML::Node GetYAMLObj() const; diff --git a/cpp_src/cluster/consts.h b/cpp_src/cluster/consts.h index 9c4a1617d..f46c64e8e 100644 --- a/cpp_src/cluster/consts.h +++ b/cpp_src/cluster/consts.h @@ -6,8 +6,8 @@ namespace reindexer { namespace cluster { -const std::string_view kAsyncReplStatsType = "async"; -const std::string_view kClusterReplStatsType = "cluster"; +constexpr std::string_view kAsyncReplStatsType = "async"; 
+constexpr std::string_view kClusterReplStatsType = "cluster"; constexpr auto kLeaderPingInterval = std::chrono::milliseconds(200); constexpr auto kMinLeaderAwaitInterval = kLeaderPingInterval * 5; diff --git a/cpp_src/cluster/replication/asyncdatareplicator.cc b/cpp_src/cluster/replication/asyncdatareplicator.cc index fc9b2a7d1..e5b5e11b1 100644 --- a/cpp_src/cluster/replication/asyncdatareplicator.cc +++ b/cpp_src/cluster/replication/asyncdatareplicator.cc @@ -5,7 +5,7 @@ namespace reindexer { namespace cluster { -AsyncDataReplicator::AsyncDataReplicator(AsyncDataReplicator::UpdatesQueueT& q, SharedSyncState<>& syncState, ReindexerImpl& thisNode, +AsyncDataReplicator::AsyncDataReplicator(AsyncDataReplicator::UpdatesQueueT& q, SharedSyncState& syncState, ReindexerImpl& thisNode, Clusterizator& clusterizator) : statsCollector_(std::string(kAsyncReplStatsType)), updatesQueue_(q), @@ -114,6 +114,10 @@ bool AsyncDataReplicator::isExpectingStartup() const noexcept { config_->role != AsyncReplConfigData::Role::None; } +size_t AsyncDataReplicator::threadsCount() const noexcept { + return config_.has_value() && config_->replThreadsCount > 0 ? config_->replThreadsCount : kDefaultReplThreadCount; +} + void AsyncDataReplicator::stop() { if (isRunning()) { for (auto& th : replThreads_) { diff --git a/cpp_src/cluster/replication/asyncdatareplicator.h b/cpp_src/cluster/replication/asyncdatareplicator.h index 40f855786..f89a0da2c 100644 --- a/cpp_src/cluster/replication/asyncdatareplicator.h +++ b/cpp_src/cluster/replication/asyncdatareplicator.h @@ -14,7 +14,7 @@ class AsyncDataReplicator { public: using UpdatesQueueT = UpdatesQueuePair; - AsyncDataReplicator(UpdatesQueueT&, SharedSyncState<>&, ReindexerImpl&, Clusterizator&); + AsyncDataReplicator(UpdatesQueueT&, SharedSyncState&, ReindexerImpl&, Clusterizator&); void Configure(AsyncReplConfigData config); void Configure(ReplicationConfigData config); @@ -29,9 +29,7 @@ class AsyncDataReplicator { private: static constexpr std::string_view logModuleName() noexcept { return std::string_view("asyncreplicator"); } bool isExpectingStartup() const noexcept; - size_t threadsCount() const noexcept { - return config_.has_value() && config_->replThreadsCount > 0 ? 
config_->replThreadsCount : kDefaultReplThreadCount; - } + size_t threadsCount() const noexcept; bool isRunning() const noexcept { return replThreads_.size(); } void stop(); NsNamesHashSetT getLocalNamespaces(); @@ -40,7 +38,7 @@ class AsyncDataReplicator { ReplicationStatsCollector statsCollector_; mutable std::mutex mtx_; UpdatesQueueT& updatesQueue_; - SharedSyncState<>& syncState_; + SharedSyncState& syncState_; ReindexerImpl& thisNode_; Clusterizator& clusterizator_; std::deque replThreads_; diff --git a/cpp_src/cluster/replication/asyncreplthread.cc b/cpp_src/cluster/replication/asyncreplthread.cc index e0a9f3980..573a6d82b 100644 --- a/cpp_src/cluster/replication/asyncreplthread.cc +++ b/cpp_src/cluster/replication/asyncreplthread.cc @@ -1,11 +1,38 @@ #include "asyncreplthread.h" +#include "sharedsyncstate.h" +#include "tools/catch_and_return.h" -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { + +AsyncThreadParam::AsyncThreadParam(const std::vector *n, AsyncReplicationMode replMode, SharedSyncState &syncState) + : nodes_(n), replMode_(replMode), syncState_(syncState) { + assert(nodes_); + } + +Error AsyncThreadParam::CheckReplicationMode(uint32_t nodeId) const noexcept { + try { + auto replMode = replMode_; + const auto& nodeReplMode = (*nodes_)[nodeId].GetReplicationMode(); + if (nodeReplMode.has_value()) { + replMode = nodeReplMode.value(); + } + if (replMode == AsyncReplicationMode::FromClusterLeader) { + const auto rp = syncState_.GetRolesPair(); + if (rp.first.role != rp.second.role || (rp.first.role != RaftInfo::Role::Leader && rp.first.role != RaftInfo::Role::None)) { + return Error(errParams, + "Current node has roles '%s:%s', but role 'leader' (or 'none') is required to replicate, when " + "replication mode set to 'from_sync_leader'", + RaftInfo::RoleToStr(rp.first.role), RaftInfo::RoleToStr(rp.second.role)); + } + } + } + CATCH_AND_RETURN; + return Error(); +} AsyncReplThread::AsyncReplThread(int serverId, ReindexerImpl& thisNode, std::shared_ptr q, const std::vector& nodesList, AsyncReplicationMode replMode, - SharedSyncState<>& syncState, ReplicationStatsCollector statsCollector, const Logger& l) + SharedSyncState& syncState, ReplicationStatsCollector statsCollector, const Logger& l) : base_(serverId, thisNode, std::move(q), AsyncThreadParam(&nodesList, replMode, syncState), statsCollector, l) {} AsyncReplThread::~AsyncReplThread() { @@ -31,5 +58,4 @@ void AsyncReplThread::AwaitTermination() { base_.SetTerminate(false); } -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/asyncreplthread.h b/cpp_src/cluster/replication/asyncreplthread.h index 7e14332dc..47e0f9ce6 100644 --- a/cpp_src/cluster/replication/asyncreplthread.h +++ b/cpp_src/cluster/replication/asyncreplthread.h @@ -2,15 +2,13 @@ #include "replicationthread.h" -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { + +class SharedSyncState; class AsyncThreadParam { public: - AsyncThreadParam(const std::vector* n, AsyncReplicationMode replMode, SharedSyncState<>& syncState) - : nodes_(n), replMode_(replMode), syncState_(syncState) { - assert(nodes_); - } + AsyncThreadParam(const std::vector* n, AsyncReplicationMode replMode, SharedSyncState& syncState); AsyncThreadParam(AsyncThreadParam&& o) = default; AsyncThreadParam(const AsyncThreadParam& o) = default; @@ -23,35 +21,19 @@ class AsyncThreadParam { void OnNodeBecameUnsynchonized(uint32_t) const noexcept {} void 
OnAllUpdatesReplicated(uint32_t, int64_t) const noexcept {} void OnUpdateSucceed(uint32_t, int64_t) const noexcept {} - Error CheckReplicationMode(uint32_t nodeId) const noexcept { - auto replMode = replMode_; - const auto& nodeReplMode = (*nodes_)[nodeId].GetReplicationMode(); - if (nodeReplMode.has_value()) { - replMode = nodeReplMode.value(); - } - if (replMode == AsyncReplicationMode::FromClusterLeader) { - const auto rp = syncState_.GetRolesPair(); - if (rp.first.role != rp.second.role || (rp.first.role != RaftInfo::Role::Leader && rp.first.role != RaftInfo::Role::None)) { - return Error(errParams, - "Current node has roles '%s:%s', but role 'leader' (or 'none') is required to replicate, when " - "replication mode set to 'from_sync_leader'", - RaftInfo::RoleToStr(rp.first.role), RaftInfo::RoleToStr(rp.second.role)); - } - } - return Error(); - } + Error CheckReplicationMode(uint32_t nodeId) const noexcept; private: const std::vector* nodes_; AsyncReplicationMode replMode_; - SharedSyncState<>& syncState_; + SharedSyncState& syncState_; }; class AsyncReplThread { public: using BaseT = ReplThread; AsyncReplThread(int serverId, ReindexerImpl& thisNode, std::shared_ptr, - const std::vector& nodesList, AsyncReplicationMode, SharedSyncState<>&, ReplicationStatsCollector, + const std::vector& nodesList, AsyncReplicationMode, SharedSyncState&, ReplicationStatsCollector, const Logger&); ~AsyncReplThread(); void Run(ReplThreadConfig config, std::vector >&& nodesList, size_t totalNodesCount); @@ -63,5 +45,4 @@ class AsyncReplThread { BaseT base_; }; -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/clusterdatareplicator.cc b/cpp_src/cluster/replication/clusterdatareplicator.cc index 7a1dc4526..06121fa07 100644 --- a/cpp_src/cluster/replication/clusterdatareplicator.cc +++ b/cpp_src/cluster/replication/clusterdatareplicator.cc @@ -1,12 +1,12 @@ #include "clusterdatareplicator.h" -#include "core/defnsconfigs.h" #include "core/reindexer_impl/reindexerimpl.h" +#include "core/system_ns_names.h" #include "tools/randomgenerator.h" namespace reindexer { namespace cluster { -ClusterDataReplicator::ClusterDataReplicator(ClusterDataReplicator::UpdatesQueueT& q, SharedSyncState<>& s, ReindexerImpl& thisNode) +ClusterDataReplicator::ClusterDataReplicator(ClusterDataReplicator::UpdatesQueueT& q, SharedSyncState& s, ReindexerImpl& thisNode) : statsCollector_(std::string(kClusterReplStatsType)), raftManager_(loop_, statsCollector_, log_, [this](uint32_t uid, bool online) { diff --git a/cpp_src/cluster/replication/clusterdatareplicator.h b/cpp_src/cluster/replication/clusterdatareplicator.h index ac4f4fbda..42fe4a9b4 100644 --- a/cpp_src/cluster/replication/clusterdatareplicator.h +++ b/cpp_src/cluster/replication/clusterdatareplicator.h @@ -16,7 +16,7 @@ class ClusterDataReplicator { using UpdatesQueueT = UpdatesQueuePair; using UpdatesQueueShardT = UpdatesQueueT::QueueT; - ClusterDataReplicator(UpdatesQueueT&, SharedSyncState<>&, ReindexerImpl&); + ClusterDataReplicator(UpdatesQueueT&, SharedSyncState&, ReindexerImpl&); void Configure(ClusterConfigData config); void Configure(ReplicationConfigData config); @@ -98,7 +98,7 @@ class ClusterDataReplicator { Logger log_; RaftManager raftManager_; UpdatesQueueT& updatesQueue_; - SharedSyncState<>& sharedSyncState_; + SharedSyncState& sharedSyncState_; ReindexerImpl& thisNode_; std::deque replThreads_; std::function requestElectionsRestartCb_; diff --git 
a/cpp_src/cluster/replication/clusterreplthread.cc b/cpp_src/cluster/replication/clusterreplthread.cc index 8e24cccde..2180e3d06 100644 --- a/cpp_src/cluster/replication/clusterreplthread.cc +++ b/cpp_src/cluster/replication/clusterreplthread.cc @@ -1,12 +1,30 @@ #include "clusterreplthread.h" #include "core/reindexer_impl/reindexerimpl.h" +#include "sharedsyncstate.h" -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { + +ClusterThreadParam::ClusterThreadParam(const NsNamesHashSetT* namespaces, coroutine::channel& ch, SharedSyncState& st, + SynchronizationList& syncList, std::function cb) + : namespaces_(namespaces), + leadershipAwaitCh_(ch), + sharedSyncState_(st), + requestElectionsRestartCb_(std::move(cb)), + syncList_(syncList) { + assert(namespaces_); +} + +void ClusterThreadParam::OnNewNsAppearance(const NamespaceName& ns) { sharedSyncState_.MarkSynchronized(ns); } + +void ClusterThreadParam::OnUpdateReplicationFailure() { + if (sharedSyncState_.GetRolesPair().second.role == RaftInfo::Role::Leader) { + requestElectionsRestartCb_(); + } +} ClusterReplThread::ClusterReplThread(int serverId, ReindexerImpl& thisNode, const NsNamesHashSetT* namespaces, std::shared_ptr> q, - SharedSyncState<>& syncState, SynchronizationList& syncList, + SharedSyncState& syncState, SynchronizationList& syncList, std::function requestElectionsRestartCb, ReplicationStatsCollector statsCollector, const Logger& l) : base_(serverId, thisNode, std::move(q), @@ -69,5 +87,4 @@ void ClusterReplThread::AwaitTermination() { void ClusterReplThread::OnRoleSwitch() { roleSwitchAsync_.send(); } -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/clusterreplthread.h b/cpp_src/cluster/replication/clusterreplthread.h index 4f948cd7d..36614ce07 100644 --- a/cpp_src/cluster/replication/clusterreplthread.h +++ b/cpp_src/cluster/replication/clusterreplthread.h @@ -3,29 +3,19 @@ #include "cluster/stats/synchronizationlist.h" #include "replicationthread.h" -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { + +class SharedSyncState; class ClusterThreadParam { public: - ClusterThreadParam(const NsNamesHashSetT* namespaces, coroutine::channel& ch, SharedSyncState<>& st, - SynchronizationList& syncList, std::function cb) - : namespaces_(namespaces), - leadershipAwaitCh_(ch), - sharedSyncState_(st), - requestElectionsRestartCb_(std::move(cb)), - syncList_(syncList) { - assert(namespaces_); - } + ClusterThreadParam(const NsNamesHashSetT* namespaces, coroutine::channel& ch, SharedSyncState& st, SynchronizationList& syncList, + std::function cb); bool IsLeader() const noexcept { return !leadershipAwaitCh_.opened(); } void AwaitReplPermission() { leadershipAwaitCh_.pop(); } - void OnNewNsAppearance(const NamespaceName& ns) { sharedSyncState_.MarkSynchronized(ns); } - void OnUpdateReplicationFailure() { - if (sharedSyncState_.GetRolesPair().second.role == RaftInfo::Role::Leader) { - requestElectionsRestartCb_(); - } - } + void OnNewNsAppearance(const NamespaceName& ns); + void OnUpdateReplicationFailure(); bool IsNamespaceInConfig(size_t, const NamespaceName& ns) const noexcept { return namespaces_->empty() || (namespaces_->find(ns) != namespaces_->end()); } @@ -40,7 +30,7 @@ class ClusterThreadParam { private: const NsNamesHashSetT* namespaces_; coroutine::channel& leadershipAwaitCh_; - SharedSyncState<>& sharedSyncState_; + SharedSyncState& sharedSyncState_; std::function requestElectionsRestartCb_; 
SynchronizationList& syncList_; }; @@ -48,7 +38,7 @@ class ClusterThreadParam { class ClusterReplThread { public: ClusterReplThread(int serverId, ReindexerImpl& thisNode, const NsNamesHashSetT*, - std::shared_ptr>, SharedSyncState<>&, + std::shared_ptr>, SharedSyncState&, SynchronizationList&, std::function requestElectionsRestartCb, ReplicationStatsCollector, const Logger&); ~ClusterReplThread(); void Run(ReplThreadConfig config, std::vector>&& nodesList, size_t totalNodesCount); @@ -61,9 +51,8 @@ class ClusterReplThread { coroutine::channel leadershipAwaitCh; net::ev::async roleSwitchAsync_; ReplThread base_; - SharedSyncState<>& sharedSyncState_; + SharedSyncState& sharedSyncState_; steady_clock_w::time_point roleSwitchTm_; }; -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/leadersyncer.cc b/cpp_src/cluster/replication/leadersyncer.cc index 9f72588cc..cb03038bf 100644 --- a/cpp_src/cluster/replication/leadersyncer.cc +++ b/cpp_src/cluster/replication/leadersyncer.cc @@ -2,13 +2,14 @@ #include "client/snapshot.h" #include "cluster/logger.h" #include "cluster/sharding/shardingcontrolrequest.h" -#include "core/defnsconfigs.h" #include "core/reindexer_impl/reindexerimpl.h" +#include "estl/gift_str.h" +#include "vendor/gason/gason.h" namespace reindexer { namespace cluster { -Error LeaderSyncer::Sync(std::list&& entries, SharedSyncState<>& sharedSyncState, ReindexerImpl& thisNode, +Error LeaderSyncer::Sync(elist&& entries, SharedSyncState& sharedSyncState, ReindexerImpl& thisNode, ReplicationStatsCollector statsCollector) { Error err; const LeaderSyncThread::Config thCfg{cfg_.dsns, cfg_.maxWALDepthOnForceSync, cfg_.clusterId, diff --git a/cpp_src/cluster/replication/leadersyncer.h b/cpp_src/cluster/replication/leadersyncer.h index 3092932db..128d9b694 100644 --- a/cpp_src/cluster/replication/leadersyncer.h +++ b/cpp_src/cluster/replication/leadersyncer.h @@ -1,11 +1,10 @@ #pragma once #include -#include #include "client/cororeindexer.h" #include "cluster/stats/relicationstatscollector.h" -#include "cluster/stats/synchronizationlist.h" #include "core/namespace/namespacestat.h" +#include "estl/elist.h" #include "net/ev/ev.h" #include "sharedsyncstate.h" #include "tools/lsn.h" @@ -46,7 +45,7 @@ class LeaderSyncQueue { LeaderSyncQueue(size_t maxSyncsPerNode) : maxSyncsPerNode_(maxSyncsPerNode) {} - void Refill(std::list&& entries) { + void Refill(elist&& entries) { std::lock_guard lck(mtx_); entries_ = std::move(entries); currentSyncsPerNode_.clear(); @@ -107,7 +106,7 @@ class LeaderSyncQueue { private: const size_t maxSyncsPerNode_; std::mutex mtx_; - std::list entries_; + elist entries_; std::map currentSyncsPerNode_; }; @@ -122,7 +121,7 @@ class LeaderSyncThread { std::chrono::milliseconds netTimeout; }; - LeaderSyncThread(const Config& cfg, LeaderSyncQueue& syncQueue, SharedSyncState<>& sharedSyncState, ReindexerImpl& thisNode, + LeaderSyncThread(const Config& cfg, LeaderSyncQueue& syncQueue, SharedSyncState& sharedSyncState, ReindexerImpl& thisNode, ReplicationStatsCollector statsCollector, const Logger& l, std::once_flag& actShardingCfg) : cfg_(cfg), syncQueue_(syncQueue), @@ -156,7 +155,7 @@ class LeaderSyncThread { LeaderSyncQueue& syncQueue_; Error lastError_; std::atomic terminate_ = false; - SharedSyncState<>& sharedSyncState_; + SharedSyncState& sharedSyncState_; ReindexerImpl& thisNode_; ReplicationStatsCollector statsCollector_; client::CoroReindexer client_; @@ -188,7 +187,7 @@ class LeaderSyncer { 
th.Terminate(); } } - Error Sync(std::list&& entries, SharedSyncState<>& sharedSyncState, ReindexerImpl& thisNode, + Error Sync(elist&& entries, SharedSyncState& sharedSyncState, ReindexerImpl& thisNode, ReplicationStatsCollector statsCollector); private: diff --git a/cpp_src/cluster/replication/replicationthread.cc b/cpp_src/cluster/replication/replicationthread.cc index 2bba774d7..675747d65 100644 --- a/cpp_src/cluster/replication/replicationthread.cc +++ b/cpp_src/cluster/replication/replicationthread.cc @@ -1,9 +1,11 @@ #include "asyncreplthread.h" +#include "cluster/consts.h" #include "cluster/sharding/shardingcontrolrequest.h" #include "clusterreplthread.h" -#include "core/defnsconfigs.h" #include "core/namespace/snapshot/snapshot.h" #include "core/reindexer_impl/reindexerimpl.h" +#include "estl/gift_str.h" +#include "sharedsyncstate.h" #include "tools/catch_and_return.h" #include "updates/updatesqueue.h" #include "updatesbatcher.h" @@ -28,6 +30,66 @@ using updates::SaveNewShardingCfgRecord; using updates::ApplyNewShardingCfgRecord; using updates::ResetShardingCfgRecord; +ReplThreadConfig::ReplThreadConfig(const ReplicationConfigData& baseConfig, const AsyncReplConfigData& config) { + AppName = config.appName; + EnableCompression = config.enableCompression; + UpdatesTimeoutSec = config.onlineUpdatesTimeoutSec; + RetrySyncIntervalMSec = config.retrySyncIntervalMSec; + ParallelSyncsPerThreadCount = config.parallelSyncsPerThreadCount; + BatchingRoutinesCount = config.batchingRoutinesCount > 0 ? size_t(config.batchingRoutinesCount) : 100; + MaxWALDepthOnForceSync = config.maxWALDepthOnForceSync; + SyncTimeoutSec = std::max(config.syncTimeoutSec, config.onlineUpdatesTimeoutSec); + ClusterID = baseConfig.clusterID; + if (config.onlineUpdatesDelayMSec > 0) { + OnlineUpdatesDelaySec = double(config.onlineUpdatesDelayMSec) / 1000.; + } else if (config.onlineUpdatesDelayMSec == 0) { + OnlineUpdatesDelaySec = 0; + } else { + OnlineUpdatesDelaySec = 0.1; + } +} + +ReplThreadConfig::ReplThreadConfig(const ReplicationConfigData& baseConfig, const ClusterConfigData& config) { + AppName = config.appName; + EnableCompression = config.enableCompression; + UpdatesTimeoutSec = config.onlineUpdatesTimeoutSec; + RetrySyncIntervalMSec = config.retrySyncIntervalMSec; + ParallelSyncsPerThreadCount = config.parallelSyncsPerThreadCount; + ClusterID = baseConfig.clusterID; + MaxWALDepthOnForceSync = config.maxWALDepthOnForceSync; + SyncTimeoutSec = std::max(config.syncTimeoutSec, config.onlineUpdatesTimeoutSec); + BatchingRoutinesCount = config.batchingRoutinesCount > 0 ? size_t(config.batchingRoutinesCount) : 100; + OnlineUpdatesDelaySec = 0; +} + +namespace repl_thread_impl { +void NamespaceData::UpdateLsnOnRecord(const updates::UpdateRecord& rec) { + if (!rec.IsDbRecord()) { + // Updates with *Namespace types have fake lsn. 
Those updates should not be counted in latestLsn + latestLsn = rec.ExtLSN(); + } else if (rec.Type() == updates::URType::AddNamespace) { + if (latestLsn.NsVersion().isEmpty() || latestLsn.NsVersion().Counter() < rec.ExtLSN().NsVersion().Counter()) { + latestLsn = ExtendedLsn(rec.ExtLSN().NsVersion(), lsn_t()); + } + } else if (rec.Type() == updates::URType::DropNamespace) { + latestLsn = ExtendedLsn(); + } +} + +void Node::Reconnect(net::ev::dynamic_loop& loop, const ReplThreadConfig& config) { + if (connObserverId.has_value()) { + auto err = client.RemoveConnectionStateObserver(*connObserverId); + (void)err; // ignored + connObserverId.reset(); + } + client.Stop(); + client::ConnectOpts opts; + opts.CreateDBIfMissing().WithExpectedClusterID(config.ClusterID); + auto err = client.Connect(dsn, loop, opts); + (void)err; // ignored; Error will be checked during the further requests +} +} // namespace repl_thread_impl + template bool UpdateApplyStatus::IsHaveToResync() const noexcept { static_assert(std::is_same_v || std::is_same_v, @@ -179,6 +241,25 @@ void ReplThread::SetTerminate(bool val) noexcept { } } +template +void ReplThread::DisconnectNodes() { + coroutine::wait_group swg; + for (auto& node : nodes) { + loop.spawn( + swg, + [&node]() noexcept { + if (node.connObserverId.has_value()) { + auto err = node.client.RemoveConnectionStateObserver(*node.connObserverId); + (void)err; // ignore + node.connObserverId.reset(); + } + node.client.Stop(); + }, + k16kCoroStack); + } + swg.wait(); +} + template constexpr bool ReplThread::isClusterReplThread() noexcept { static_assert(std::is_same_v || std::is_same_v, @@ -444,7 +525,7 @@ Error ReplThread::nodeReplicationImpl(Node& node) { } localWg.wait(); if (!integralError.ok()) { - logWarn("%d:%d Unable to sync remote namespaces: %s", serverId_, node.uid, integralError.what()); + logError("%d:%d Unable to sync remote namespaces: %s", serverId_, node.uid, integralError.what()); return integralError; } updateNodeStatus(node.uid, NodeStats::Status::Online); @@ -685,7 +766,7 @@ UpdateApplyStatus ReplThread::nodeUpdatesHandlingLoop(Node& nod NamespaceData* nsData; uint16_t offset; }; - UpdatesChT& updatesNotifier = *node.updateNotifier; + auto& updatesNotifier = *node.updateNotifier; bool requireReelections = false; auto applyUpdateF = [this, &node](const UpdatesQueueT::UpdateT::Value& upd, Context& ctx) { @@ -886,7 +967,7 @@ UpdateApplyStatus ReplThread::nodeUpdatesHandlingLoop(Node& nod template bool ReplThread::handleUpdatesWithError(Node& node, const Error& err) { - UpdatesChT& updatesNotifier = *node.updateNotifier; + auto& updatesNotifier = *node.updateNotifier; UpdatesQueueT::UpdatePtr updatePtr; bool hadErrorOnLastUpdate = false; diff --git a/cpp_src/cluster/replication/replicationthread.h b/cpp_src/cluster/replication/replicationthread.h index 3c85a3b33..c02c2d265 100644 --- a/cpp_src/cluster/replication/replicationthread.h +++ b/cpp_src/cluster/replication/replicationthread.h @@ -1,6 +1,5 @@ #pragma once -#include #include "client/cororeindexer.h" #include "cluster/config.h" #include "cluster/logger.h" @@ -8,7 +7,6 @@ #include "core/dbconfig.h" #include "coroutine/tokens_pool.h" #include "net/ev/ev.h" -#include "sharedsyncstate.h" #include "updates/updaterecord.h" #include "updates/updatesqueue.h" @@ -22,36 +20,8 @@ constexpr size_t kUpdatesContainerOverhead = 48; struct ReplThreadConfig { ReplThreadConfig() = default; - ReplThreadConfig(const ReplicationConfigData& baseConfig, const AsyncReplConfigData& config) { - AppName = config.appName; - 
EnableCompression = config.enableCompression; - UpdatesTimeoutSec = config.onlineUpdatesTimeoutSec; - RetrySyncIntervalMSec = config.retrySyncIntervalMSec; - ParallelSyncsPerThreadCount = config.parallelSyncsPerThreadCount; - BatchingRoutinesCount = config.batchingRoutinesCount > 0 ? size_t(config.batchingRoutinesCount) : 100; - MaxWALDepthOnForceSync = config.maxWALDepthOnForceSync; - SyncTimeoutSec = std::max(config.syncTimeoutSec, config.onlineUpdatesTimeoutSec); - ClusterID = baseConfig.clusterID; - if (config.onlineUpdatesDelayMSec > 0) { - OnlineUpdatesDelaySec = double(config.onlineUpdatesDelayMSec) / 1000.; - } else if (config.onlineUpdatesDelayMSec == 0) { - OnlineUpdatesDelaySec = 0; - } else { - OnlineUpdatesDelaySec = 0.1; - } - } - ReplThreadConfig(const ReplicationConfigData& baseConfig, const ClusterConfigData& config) { - AppName = config.appName; - EnableCompression = config.enableCompression; - UpdatesTimeoutSec = config.onlineUpdatesTimeoutSec; - RetrySyncIntervalMSec = config.retrySyncIntervalMSec; - ParallelSyncsPerThreadCount = config.parallelSyncsPerThreadCount; - ClusterID = baseConfig.clusterID; - MaxWALDepthOnForceSync = config.maxWALDepthOnForceSync; - SyncTimeoutSec = std::max(config.syncTimeoutSec, config.onlineUpdatesTimeoutSec); - BatchingRoutinesCount = config.batchingRoutinesCount > 0 ? size_t(config.batchingRoutinesCount) : 100; - OnlineUpdatesDelaySec = 0; - } + ReplThreadConfig(const ReplicationConfigData& baseConfig, const AsyncReplConfigData& config); + ReplThreadConfig(const ReplicationConfigData& baseConfig, const ClusterConfigData& config); std::string AppName = "rx_node"; int UpdatesTimeoutSec = 20; @@ -74,60 +44,42 @@ struct UpdateApplyStatus { updates::URType type; }; -template -class ReplThread { +namespace repl_thread_impl { +class NamespaceData { public: - using UpdatesQueueT = updates::UpdatesQueue; - using UpdatesChT = coroutine::channel; + void UpdateLsnOnRecord(const updates::UpdateRecord& rec); - class NamespaceData { - public: - void UpdateLsnOnRecord(const updates::UpdateRecord& rec) { - if (!rec.IsDbRecord()) { - // Updates with *Namespace types have fake lsn. 
Those updates should not be count in latestLsn - latestLsn = rec.ExtLSN(); - } else if (rec.Type() == updates::URType::AddNamespace) { - if (latestLsn.NsVersion().isEmpty() || latestLsn.NsVersion().Counter() < rec.ExtLSN().NsVersion().Counter()) { - latestLsn = ExtendedLsn(rec.ExtLSN().NsVersion(), lsn_t()); - } - } else if (rec.Type() == updates::URType::DropNamespace) { - latestLsn = ExtendedLsn(); - } - } + ExtendedLsn latestLsn; + client::CoroTransaction tx; + bool requiresTmUpdate = true; + bool isClosed = false; +}; - ExtendedLsn latestLsn; - client::CoroTransaction tx; - bool requiresTmUpdate = true; - bool isClosed = false; - }; - - struct Node { - Node(int _serverId, uint32_t _uid, const client::ReindexerConfig& config) noexcept - : serverId(_serverId), uid(_uid), client(config) {} - void Reconnect(net::ev::dynamic_loop& loop, const ReplThreadConfig& config) { - if (connObserverId.has_value()) { - auto err = client.RemoveConnectionStateObserver(*connObserverId); - (void)err; // ignored - connObserverId.reset(); - } - client.Stop(); - client::ConnectOpts opts; - opts.CreateDBIfMissing().WithExpectedClusterID(config.ClusterID); - auto err = client.Connect(dsn, loop, opts); - (void)err; // ignored; Error will be checked during the further requests - } +struct Node { + using UpdatesChT = coroutine::channel; + + Node(int _serverId, uint32_t _uid, const client::ReindexerConfig& config) noexcept : serverId(_serverId), uid(_uid), client(config) {} + void Reconnect(net::ev::dynamic_loop& loop, const ReplThreadConfig& config); + + int serverId; + uint32_t uid; + DSN dsn; + client::CoroReindexer client; + std::unique_ptr updateNotifier = std::make_unique(); + std::unordered_map + namespaceData; // This map should not invalidate references + uint64_t nextUpdateId = 0; + bool requireResync = false; + std::optional connObserverId; +}; +} // namespace repl_thread_impl - int serverId; - uint32_t uid; - DSN dsn; - client::CoroReindexer client; - std::unique_ptr updateNotifier = std::make_unique(); - std::unordered_map - namespaceData; // This map should not invalidate references - uint64_t nextUpdateId = 0; - bool requireResync = false; - std::optional connObserverId; - }; +template +class ReplThread { +public: + using UpdatesQueueT = updates::UpdatesQueue; + using Node = repl_thread_impl::Node; + using NamespaceData = repl_thread_impl::NamespaceData; ReplThread(int serverId_, ReindexerImpl& thisNode, std::shared_ptr, BehaviourParamT&&, ReplicationStatsCollector, const Logger&); @@ -137,23 +89,7 @@ class ReplThread { size_t requiredReplicas); void SetTerminate(bool val) noexcept; bool Terminated() const noexcept { return terminate_; } - void DisconnectNodes() { - coroutine::wait_group swg; - for (auto& node : nodes) { - loop.spawn( - swg, - [&node]() noexcept { - if (node.connObserverId.has_value()) { - auto err = node.client.RemoveConnectionStateObserver(*node.connObserverId); - (void)err; // ignore - node.connObserverId.reset(); - } - node.client.Stop(); - }, - k16kCoroStack); - } - swg.wait(); - } + void DisconnectNodes(); void SetNodesRequireResync() { for (auto& node : nodes) { node.requireResync = true; diff --git a/cpp_src/cluster/replication/roleswitcher.cc b/cpp_src/cluster/replication/roleswitcher.cc index ddee522d9..983d7eab3 100644 --- a/cpp_src/cluster/replication/roleswitcher.cc +++ b/cpp_src/cluster/replication/roleswitcher.cc @@ -1,17 +1,14 @@ #include "roleswitcher.h" -#include "client/snapshot.h" #include "cluster/logger.h" #include "core/reindexer_impl/reindexerimpl.h" -#include 
"coroutine/tokens_pool.h" #include "net/cproto/cproto.h" -#include "tools/logger.h" namespace reindexer { namespace cluster { constexpr auto kLeaderNsResyncInterval = std::chrono::milliseconds(1000); -RoleSwitcher::RoleSwitcher(SharedSyncState<>& syncState, SynchronizationList& syncList, ReindexerImpl& thisNode, +RoleSwitcher::RoleSwitcher(SharedSyncState& syncState, SynchronizationList& syncList, ReindexerImpl& thisNode, const ReplicationStatsCollector& statsCollector, const Logger& l) : sharedSyncState_(syncState), thisNode_(thisNode), statsCollector_(statsCollector), syncList_(syncList), log_(l) { leaderResyncTimer_.set(loop_); @@ -55,6 +52,25 @@ void RoleSwitcher::Run(std::vector&& dsns, RoleSwitcher::Config&& cfg) { roleSwitchAsync_.stop(); } +void RoleSwitcher::OnRoleChanged() { + std::lock_guard lck(mtx_); + if (syncer_) { + syncer_->Terminate(); + } + roleSwitchAsync_.send(); +} + +void RoleSwitcher::SetTerminationFlag(bool val) noexcept { + std::lock_guard lck(mtx_); + terminate_ = val; + if (val) { + if (syncer_) { + syncer_->Terminate(); + } + roleSwitchAsync_.send(); + } +} + void RoleSwitcher::await() { awaitCh_.pop(); if (!isTerminated()) { @@ -62,6 +78,14 @@ void RoleSwitcher::await() { } } +void RoleSwitcher::notify() { + if (!awaitCh_.full()) { + awaitCh_.push(true); + } +} + +void RoleSwitcher::terminate() { awaitCh_.close(); } + void RoleSwitcher::handleRoleSwitch() { auto rolesPair = sharedSyncState_.GetRolesPair(); auto& newState = rolesPair.second; @@ -195,7 +219,7 @@ void RoleSwitcher::initialLeadersSync() { Error lastErr; coroutine::wait_group leaderSyncWg; std::vector synchronizedNodes(nodes_.size(), SynchronizationList::kEmptyID); - std::list syncQueue; + elist syncQueue; NsNamesHashSetT nsList; if (cfg_.namespaces.size()) { @@ -298,7 +322,7 @@ Error RoleSwitcher::awaitRoleSwitchForNamespace(client::CoroReindexer& client, c } while (true); } -Error RoleSwitcher::getNodesListForNs(const NamespaceName& nsName, std::list& syncQueue) { +Error RoleSwitcher::getNodesListForNs(const NamespaceName& nsName, elist& syncQueue) { // 1) Find most recent data among all the followers LeaderSyncQueue::Entry nsEntry; nsEntry.nsName = nsName; diff --git a/cpp_src/cluster/replication/roleswitcher.h b/cpp_src/cluster/replication/roleswitcher.h index 57df7fd00..bb62176f1 100644 --- a/cpp_src/cluster/replication/roleswitcher.h +++ b/cpp_src/cluster/replication/roleswitcher.h @@ -30,26 +30,11 @@ class RoleSwitcher { int maxConcurrentSnapshotsPerNode = -1; }; - RoleSwitcher(SharedSyncState<>&, SynchronizationList&, ReindexerImpl&, const ReplicationStatsCollector&, const Logger&); + RoleSwitcher(SharedSyncState&, SynchronizationList&, ReindexerImpl&, const ReplicationStatsCollector&, const Logger&); void Run(std::vector&& dsns, RoleSwitcher::Config&& cfg); - void OnRoleChanged() { - std::lock_guard lck(mtx_); - if (syncer_) { - syncer_->Terminate(); - } - roleSwitchAsync_.send(); - } - void SetTerminationFlag(bool val) noexcept { - std::lock_guard lck(mtx_); - terminate_ = val; - if (val) { - if (syncer_) { - syncer_->Terminate(); - } - roleSwitchAsync_.send(); - } - } + void OnRoleChanged(); + void SetTerminationFlag(bool val) noexcept; private: struct Node { @@ -59,19 +44,15 @@ class RoleSwitcher { static constexpr std::string_view logModuleName() noexcept { return std::string_view("roleswitcher"); } void await(); - void notify() { - if (!awaitCh_.full()) { - awaitCh_.push(true); - } - } - void terminate() { awaitCh_.close(); } + void notify(); + void terminate(); void 
handleRoleSwitch(); template void switchNamespaces(const RaftInfo& state, const ContainerT& namespaces); void handleInitialSync(RaftInfo::Role newRole); void initialLeadersSync(); Error awaitRoleSwitchForNamespace(client::CoroReindexer& client, const NamespaceName& nsName, ReplicationStateV2& st); - Error getNodesListForNs(const NamespaceName& nsName, std::list& syncQueue); + Error getNodesListForNs(const NamespaceName& nsName, elist& syncQueue); NsNamesHashSetT collectNsNames(); template Error appendNsNamesFrom(RxT& rx, NsNamesHashSetT& set); @@ -82,7 +63,7 @@ class RoleSwitcher { std::vector nodes_; net::ev::dynamic_loop loop_; - SharedSyncState<>& sharedSyncState_; + SharedSyncState& sharedSyncState_; ReindexerImpl& thisNode_; ReplicationStatsCollector statsCollector_; steady_clock_w::time_point roleSwitchTm_; diff --git a/cpp_src/cluster/replication/sharedsyncstate.cc b/cpp_src/cluster/replication/sharedsyncstate.cc new file mode 100644 index 000000000..6ff064ea1 --- /dev/null +++ b/cpp_src/cluster/replication/sharedsyncstate.cc @@ -0,0 +1,67 @@ +#include "sharedsyncstate.h" + +namespace reindexer::cluster { + +void SharedSyncState::MarkSynchronized(NamespaceName name) { + std::unique_lock lck(mtx_); + assertrx_dbg(!name.empty()); + if (current_.role == RaftInfo::Role::Leader) { + auto res = synchronized_.emplace(std::move(name)); + lck.unlock(); + if (res.second) { + cond_.notify_all(); + } + } +} + +void SharedSyncState::MarkSynchronized() { + std::unique_lock lck(mtx_); + if (current_.role == RaftInfo::Role::Leader) { + ++initialSyncDoneCnt_; + lck.unlock(); + cond_.notify_all(); + } +} + +void SharedSyncState::Reset(ContainerT requireSynchronization, size_t ReplThreadsCnt, bool enabled) { + std::lock_guard lck(mtx_); + requireSynchronization_ = std::move(requireSynchronization); + synchronized_.clear(); + enabled_ = enabled; + terminated_ = false; + initialSyncDoneCnt_ = 0; + ReplThreadsCnt_ = ReplThreadsCnt; + next_ = current_ = RaftInfo(); + assert(ReplThreadsCnt_); +} + +RaftInfo SharedSyncState::TryTransitRole(RaftInfo expected) { + std::unique_lock lck(mtx_); + if (expected == next_) { + if (current_.role == RaftInfo::Role::Leader && current_.role != next_.role) { + synchronized_.clear(); + initialSyncDoneCnt_ = 0; + } + current_ = next_; + lck.unlock(); + cond_.notify_all(); + return expected; + } + return next_; +} + +void SharedSyncState::SetRole(RaftInfo info) { + std::lock_guard lck(mtx_); + next_ = std::move(info); +} + +void SharedSyncState::SetTerminated() { + { + std::lock_guard lck(mtx_); + terminated_ = true; + next_ = current_ = RaftInfo(); + } + cond_.notify_all(); +} + +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/sharedsyncstate.h b/cpp_src/cluster/replication/sharedsyncstate.h index d40f0a3e2..c67b9c622 100644 --- a/cpp_src/cluster/replication/sharedsyncstate.h +++ b/cpp_src/cluster/replication/sharedsyncstate.h @@ -3,50 +3,22 @@ #include "cluster/config.h" #include "core/namespace/namespacename.h" #include "estl/contexted_cond_var.h" -#include "estl/fast_hash_set.h" #include "estl/shared_mutex.h" -namespace reindexer { -namespace cluster { +namespace reindexer::cluster { static constexpr size_t k16kCoroStack = 16 * 1024; -template class SharedSyncState { + using MtxT = shared_timed_mutex; + public: using GetNameF = std::function; using ContainerT = NsNamesHashSetT; - void MarkSynchronized(NamespaceName name) { - std::unique_lock lck(mtx_); - assertrx_dbg(!name.empty()); - if (current_.role == RaftInfo::Role::Leader) { - auto res = 
synchronized_.emplace(std::move(name)); - lck.unlock(); - if (res.second) { - cond_.notify_all(); - } - } - } - void MarkSynchronized() { - std::unique_lock lck(mtx_); - if (current_.role == RaftInfo::Role::Leader) { - ++initialSyncDoneCnt_; - lck.unlock(); - cond_.notify_all(); - } - } - void Reset(ContainerT requireSynchronization, size_t ReplThreadsCnt, bool enabled) { - std::lock_guard lck(mtx_); - requireSynchronization_ = std::move(requireSynchronization); - synchronized_.clear(); - enabled_ = enabled; - terminated_ = false; - initialSyncDoneCnt_ = 0; - ReplThreadsCnt_ = ReplThreadsCnt; - next_ = current_ = RaftInfo(); - assert(ReplThreadsCnt_); - } + void MarkSynchronized(NamespaceName name); + void MarkSynchronized(); + void Reset(ContainerT requireSynchronization, size_t ReplThreadsCnt, bool enabled); template void AwaitInitialSync(const NamespaceName& name, const ContextT& ctx) const { shared_lock lck(mtx_); @@ -85,20 +57,7 @@ class SharedSyncState { shared_lock lck(mtx_); return isInitialSyncDone(); } - RaftInfo TryTransitRole(RaftInfo expected) { - std::unique_lock lck(mtx_); - if (expected == next_) { - if (current_.role == RaftInfo::Role::Leader && current_.role != next_.role) { - synchronized_.clear(); - initialSyncDoneCnt_ = 0; - } - current_ = next_; - lck.unlock(); - cond_.notify_all(); - return expected; - } - return next_; - } + RaftInfo TryTransitRole(RaftInfo expected); template RaftInfo AwaitRole(bool allowTransitState, const ContextT& ctx) const { shared_lock lck(mtx_); @@ -115,10 +74,7 @@ class SharedSyncState { } return current_; } - void SetRole(RaftInfo info) { - std::lock_guard lck(mtx_); - next_ = info; - } + void SetRole(RaftInfo info); std::pair GetRolesPair() const { shared_lock lck(mtx_); return std::make_pair(current_, next_); @@ -127,14 +83,7 @@ class SharedSyncState { shared_lock lck(mtx_); return current_; } - void SetTerminated() { - { - std::lock_guard lck(mtx_); - terminated_ = true; - next_ = current_ = RaftInfo(); - } - cond_.notify_all(); - } + void SetTerminated(); private: bool isInitialSyncDone(const NamespaceName& name) const { @@ -161,5 +110,4 @@ class SharedSyncState { size_t initialSyncDoneCnt_ = 0; size_t ReplThreadsCnt_ = 0; }; -} // namespace cluster -} // namespace reindexer +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/updatesqueuepair.cc b/cpp_src/cluster/replication/updatesqueuepair.cc new file mode 100644 index 000000000..e959ac4a1 --- /dev/null +++ b/cpp_src/cluster/replication/updatesqueuepair.cc @@ -0,0 +1,124 @@ +#include "updatesqueuepair.h" +#include "cluster/logger.h" +#include "core/formatters/namespacesname_fmt.h" +#include "core/namespace/namespacename.h" + +namespace reindexer::cluster { + +template +UpdatesQueuePair::UpdatesQueuePair(uint64_t maxDataSize) + : syncQueue_(std::make_shared(maxDataSize)), asyncQueue_(std::make_shared(maxDataSize)) {} + +template +typename UpdatesQueuePair::Pair UpdatesQueuePair::GetQueue(const NamespaceName& token) const { + const size_t hash = token.hash(); + Pair result; + shared_lock lck(mtx_); + if (syncQueue_->TokenIsInWhiteList(token, hash)) { + result.sync = syncQueue_; + } + if (asyncQueue_->TokenIsInWhiteList(token, hash)) { + result.async = asyncQueue_; + } + return result; +} + +template +std::shared_ptr::QueueT> UpdatesQueuePair::GetSyncQueue() const { + shared_lock lck(mtx_); + return syncQueue_; +} + +template +std::shared_ptr::QueueT> UpdatesQueuePair::GetAsyncQueue() const { + shared_lock lck(mtx_); + return asyncQueue_; +} + +template 
+std::pair UpdatesQueuePair::PushNowait(UpdatesContainerT&& data) { + const auto shardPair = GetQueue(data[0].NsName()); + if (shardPair.sync) { + if (shardPair.async) { + shardPair.async->template PushAsync(copyUpdatesContainer(data)); + } + return shardPair.sync->template PushAsync(std::move(data)); + } else if (shardPair.async) { + return shardPair.async->template PushAsync(std::move(data)); + } + return std::make_pair(Error(), false); +} + +template +std::pair UpdatesQueuePair::PushAsync(UpdatesContainerT&& data) { + std::shared_ptr shard; + { + std::string_view token(data[0].NsName()); + const HashT h; + const size_t hash = h(token); + shared_lock lck(mtx_); + if (!asyncQueue_->TokenIsInWhiteList(token, hash)) { + return std::make_pair(Error(), false); + } + shard = asyncQueue_; + } + return shard->template PushAsync(std::move(data)); +} + +template +typename UpdatesQueuePair::UpdatesContainerT UpdatesQueuePair::copyUpdatesContainer(const UpdatesContainerT& data) { + UpdatesContainerT copy; + copy.reserve(data.size()); + for (auto& d : data) { + // async replication should not see emitter + copy.emplace_back(d.template Clone()); + } + return copy; +} + +template +template +void UpdatesQueuePair::ReinitSyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, + const Logger& l) { + std::lock_guard lck(mtx_); + const auto maxDataSize = syncQueue_->MaxDataSize; + syncQueue_ = std::make_shared(maxDataSize, statsCollector); + syncQueue_->Init(std::move(allowList), &l); +} + +template +template +void UpdatesQueuePair::ReinitAsyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, + const Logger& l) { + std::lock_guard lck(mtx_); + const auto maxDataSize = asyncQueue_->MaxDataSize; + asyncQueue_ = std::make_shared(maxDataSize, statsCollector); + asyncQueue_->Init(std::move(allowList), &l); +} + +template +template +std::pair UpdatesQueuePair::Push(UpdatesContainerT&& data, std::function beforeWait, const ContextT& ctx) { + const auto shardPair = GetQueue(data[0].NsName()); + if (shardPair.sync) { + if (shardPair.async) { + shardPair.async->template PushAsync(copyUpdatesContainer(data)); + } + return shardPair.sync->PushAndWait(std::move(data), std::move(beforeWait), ctx); + } else if (shardPair.async) { + return shardPair.async->template PushAsync(std::move(data)); + } + return std::make_pair(Error(), false); +} + +template class UpdatesQueuePair; +template std::pair UpdatesQueuePair::Push( + typename UpdatesQueuePair::UpdatesContainerT&&, std::function, const RdxContext&); + +template void UpdatesQueuePair::ReinitSyncQueue(ReplicationStatsCollector, std::optional&&, + const Logger&); + +template void UpdatesQueuePair::ReinitAsyncQueue(ReplicationStatsCollector, std::optional&&, + const Logger&); + +} // namespace reindexer::cluster diff --git a/cpp_src/cluster/replication/updatesqueuepair.h b/cpp_src/cluster/replication/updatesqueuepair.h index f3966a236..f6ed669f0 100644 --- a/cpp_src/cluster/replication/updatesqueuepair.h +++ b/cpp_src/cluster/replication/updatesqueuepair.h @@ -1,15 +1,21 @@ #pragma once -#include "cluster/logger.h" #include "cluster/stats/relicationstatscollector.h" -#include "core/namespace/namespacename.h" +#include "updates/updaterecord.h" #include "updates/updatesqueue.h" namespace reindexer { + +class NamespaceName; + namespace cluster { -template +class Logger; + +template class UpdatesQueuePair { + using MtxT = read_write_spinlock; + public: using HashT = nocase_hash_str; using CompareT = nocase_equal_str; @@ -22,98 
+28,28 @@ class UpdatesQueuePair { std::shared_ptr async; }; - UpdatesQueuePair(uint64_t maxDataSize) - : syncQueue_(std::make_shared(maxDataSize)), asyncQueue_(std::make_shared(maxDataSize)) {} - - Pair GetQueue(const NamespaceName& token) const { - const size_t hash = token.hash(); - Pair result; - shared_lock lck(mtx_); - if (syncQueue_->TokenIsInWhiteList(token, hash)) { - result.sync = syncQueue_; - } - if (asyncQueue_->TokenIsInWhiteList(token, hash)) { - result.async = asyncQueue_; - } - return result; - } - std::shared_ptr GetSyncQueue() const { - shared_lock lck(mtx_); - return syncQueue_; - } - std::shared_ptr GetAsyncQueue() const { - shared_lock lck(mtx_); - return asyncQueue_; - } + UpdatesQueuePair(uint64_t maxDataSize); + Pair GetQueue(const NamespaceName& token) const; + std::shared_ptr GetSyncQueue() const; + std::shared_ptr GetAsyncQueue() const; template - void ReinitSyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, const Logger& l) { - std::lock_guard lck(mtx_); - const auto maxDataSize = syncQueue_->MaxDataSize; - syncQueue_ = std::make_shared(maxDataSize, statsCollector); - syncQueue_->Init(std::move(allowList), &l); - } + void ReinitSyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, const Logger& l); template - void ReinitAsyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, const Logger& l) { - std::lock_guard lck(mtx_); - const auto maxDataSize = asyncQueue_->MaxDataSize; - asyncQueue_ = std::make_shared(maxDataSize, statsCollector); - asyncQueue_->Init(std::move(allowList), &l); - } + void ReinitAsyncQueue(ReplicationStatsCollector statsCollector, std::optional&& allowList, const Logger& l); template - std::pair Push(UpdatesContainerT&& data, std::function beforeWait, const ContextT& ctx) { - const auto shardPair = GetQueue(data[0].NsName()); - if (shardPair.sync) { - if (shardPair.async) { - shardPair.async->template PushAsync(copyUpdatesContainer(data)); - } - return shardPair.sync->PushAndWait(std::move(data), std::move(beforeWait), ctx); - } else if (shardPair.async) { - return shardPair.async->template PushAsync(std::move(data)); - } - return std::make_pair(Error(), false); - } - std::pair PushNowait(UpdatesContainerT&& data) { - const auto shardPair = GetQueue(data[0].NsName()); - if (shardPair.sync) { - if (shardPair.async) { - shardPair.async->template PushAsync(copyUpdatesContainer(data)); - } - return shardPair.sync->template PushAsync(std::move(data)); - } else if (shardPair.async) { - return shardPair.async->template PushAsync(std::move(data)); - } - return std::make_pair(Error(), false); - } - std::pair PushAsync(UpdatesContainerT&& data) { - std::shared_ptr shard; - { - std::string_view token(data[0].NsName()); - const HashT h; - const size_t hash = h(token); - shared_lock lck(mtx_); - if (!asyncQueue_->TokenIsInWhiteList(token, hash)) { - return std::make_pair(Error(), false); - } - shard = asyncQueue_; - } - return shard->template PushAsync(std::move(data)); - } + std::pair Push(UpdatesContainerT&& data, std::function beforeWait, const ContextT& ctx); + std::pair PushNowait(UpdatesContainerT&& data); + std::pair PushAsync(UpdatesContainerT&& data); private: - UpdatesContainerT copyUpdatesContainer(const UpdatesContainerT& data) { - UpdatesContainerT copy; - copy.reserve(data.size()); - for (auto& d : data) { - // async replication should not see emmiter - copy.emplace_back(d.template Clone()); - } - return copy; - } + UpdatesContainerT copyUpdatesContainer(const 
UpdatesContainerT& data); mutable MtxT mtx_; std::shared_ptr syncQueue_; std::shared_ptr asyncQueue_; }; +extern template class UpdatesQueuePair; + } // namespace cluster } // namespace reindexer diff --git a/cpp_src/cluster/sharding/sharding.cc b/cpp_src/cluster/sharding/sharding.cc index acc3f8c27..b7fafea6a 100644 --- a/cpp_src/cluster/sharding/sharding.cc +++ b/cpp_src/cluster/sharding/sharding.cc @@ -1,9 +1,10 @@ #include "sharding.h" +#include "cluster/consts.h" #include "cluster/stats/replicationstats.h" #include "core/clusterproxy.h" -#include "core/defnsconfigs.h" #include "core/item.h" #include "core/type_consts.h" +#include "estl/gift_str.h" #include "tools/logger.h" namespace reindexer { @@ -20,7 +21,7 @@ bool RoutingStrategy::getHostIdForQuery(const Query& q, int& hostId, Variant& sh for (auto it = q.Entries().cbegin(), next = it, end = q.Entries().cend(); it != end; ++it) { ++next; it->Visit( - Skip{}, + Skip{}, [&](const QueryEntry& qe) { if (containsKey) { if (keys_.IsShardIndex(ns, qe.FieldName())) { @@ -63,7 +64,7 @@ bool RoutingStrategy::getHostIdForQuery(const Query& q, int& hostId, Variant& sh [&](const Bracket&) { for (auto i = it.cbegin().PlainIterator(), end = it.cend().PlainIterator(); i != end; ++i) { i->Visit( - Skip{}, + Skip{}, [&](const QueryEntry& qe) { if (keys_.IsShardIndex(ns, qe.FieldName())) { throw Error(errLogic, "Shard key condition cannot be included in bracket"); @@ -355,8 +356,8 @@ LocatorService::LocatorService(ClusterProxy& rx, cluster::ShardingConfig config) Error LocatorService::convertShardingKeysValues(KeyValueType fieldType, std::vector& keys) { return fieldType.EvaluateOneOf( - [&](OneOf) -> Error { + [&](OneOf) -> Error { try { for (auto& k : keys) { for (auto& [l, r, _] : k.values) { @@ -371,7 +372,7 @@ Error LocatorService::convertShardingKeysValues(KeyValueType fieldType, std::vec return Error(); }, [](OneOf) { return Error{errLogic, "Sharding by composite index is unsupported"}; }, - [fieldType](OneOf) { + [fieldType](OneOf) { return Error{errLogic, "Unsupported field type: %s", fieldType.Name()}; }); } @@ -594,5 +595,26 @@ std::shared_ptr LocatorService::getShardConnection(int shardI return peekHostForShard(hostsConnections_[shardId], shardId, status); } +Connections::Connections(Connections&& obj) noexcept + : base(static_cast(obj)), + actualIndex(std::move(obj.actualIndex)), + reconnectTs(obj.reconnectTs), + status(std::move(obj.status)), + shutdown(obj.shutdown) {} + +Connections::Connections(const Connections& obj) noexcept + : base(obj), actualIndex(obj.actualIndex), reconnectTs(obj.reconnectTs), status(obj.status), shutdown(obj.shutdown) {} + +void Connections::Shutdown() { + std::lock_guard lck(m); + if (!shutdown) { + for (auto& conn : *this) { + conn->Stop(); + } + shutdown = true; + status = Error(errTerminated, "Sharding proxy is already shut down"); + } +} + } // namespace sharding } // namespace reindexer diff --git a/cpp_src/cluster/sharding/sharding.h b/cpp_src/cluster/sharding/sharding.h index baffd438e..bd70e5734 100644 --- a/cpp_src/cluster/sharding/sharding.h +++ b/cpp_src/cluster/sharding/sharding.h @@ -42,27 +42,11 @@ class RoutingStrategy { class Connections : public std::vector> { public: using base = std::vector>; - Connections() : base() {} - Connections(Connections&& obj) noexcept - : base(static_cast(obj)), - actualIndex(std::move(obj.actualIndex)), - reconnectTs(obj.reconnectTs), - status(std::move(obj.status)), - shutdown(obj.shutdown) {} - - Connections(const Connections& obj) noexcept - : base(obj), 
actualIndex(obj.actualIndex), reconnectTs(obj.reconnectTs), status(obj.status), shutdown(obj.shutdown) {} - - void Shutdown() { - std::lock_guard lck(m); - if (!shutdown) { - for (auto& conn : *this) { - conn->Stop(); - } - shutdown = true; - status = Error(errTerminated, "Sharding proxy is already shut down"); - } - } + Connections() = default; + Connections(Connections&& obj) noexcept; + Connections(const Connections& obj) noexcept; + + void Shutdown(); shared_timed_mutex m; std::optional actualIndex; steady_clock_w::time_point reconnectTs; diff --git a/cpp_src/cluster/sharding/shardingcontrolrequest.cc b/cpp_src/cluster/sharding/shardingcontrolrequest.cc index 399c551ac..e85c814f9 100644 --- a/cpp_src/cluster/sharding/shardingcontrolrequest.cc +++ b/cpp_src/cluster/sharding/shardingcontrolrequest.cc @@ -1,6 +1,7 @@ #include "shardingcontrolrequest.h" #include "core/cjson/jsonbuilder.h" #include "tools/catch_and_return.h" +#include "vendor/gason/gason.h" namespace reindexer::sharding { @@ -15,7 +16,7 @@ static void getJSON(const ControlDataT& shardingControl, WrSerializer& ser) { } template -static Error fromJSON(ControlDataT& shardingControl, span json) noexcept { +static Error fromJSON(ControlDataT& shardingControl, std::span json) noexcept { try { gason::JsonParser parser; auto node = parser.Parse(json); @@ -29,8 +30,8 @@ static Error fromJSON(ControlDataT& shardingControl, span json) noexcept { void ShardingControlRequestData::GetJSON(WrSerializer& ser) const { return getJSON(*this, ser); } void ShardingControlResponseData::GetJSON(WrSerializer& ser) const { return getJSON(*this, ser); } -Error ShardingControlRequestData::FromJSON(span json) noexcept { return fromJSON(*this, json); } -Error ShardingControlResponseData::FromJSON(span json) noexcept { return fromJSON(*this, json); } +Error ShardingControlRequestData::FromJSON(std::span json) noexcept { return fromJSON(*this, json); } +Error ShardingControlResponseData::FromJSON(std::span json) noexcept { return fromJSON(*this, json); } void ApplyLeaderConfigCommand::GetJSON(JsonBuilder& json) const { json.Put("config", config); @@ -66,7 +67,8 @@ void ResetConfigCommand::FromJSON(const gason::JsonNode& payload) { sourceId = p void GetNodeConfigCommand::GetJSON(JsonBuilder& json) const { json.Put("config", config.GetJSON(cluster::MaskingDSN(masking))); } void GetNodeConfigCommand::FromJSON(const gason::JsonNode& payload) { - auto err = config.FromJSON(payload["config"].As()); + auto cfg = payload["config"].As(); + auto err = config.FromJSON(std::span(cfg)); if (!err.ok()) { throw err; } diff --git a/cpp_src/cluster/sharding/shardingcontrolrequest.h b/cpp_src/cluster/sharding/shardingcontrolrequest.h index 25f0ff123..53fb41b15 100644 --- a/cpp_src/cluster/sharding/shardingcontrolrequest.h +++ b/cpp_src/cluster/sharding/shardingcontrolrequest.h @@ -92,7 +92,7 @@ void assign_if_constructible(T& data, Args&&... 
args) { struct ShardingControlRequestData { ShardingControlRequestData() noexcept = default; - Error FromJSON(span json) noexcept; + Error FromJSON(std::span json) noexcept; void GetJSON(WrSerializer& ser) const; template @@ -126,7 +126,7 @@ struct ShardingControlRequestData { struct ShardingControlResponseData { ShardingControlResponseData() noexcept = default; - Error FromJSON(span json) noexcept; + Error FromJSON(std::span json) noexcept; void GetJSON(WrSerializer& ser) const; template diff --git a/cpp_src/cluster/stats/relicationstatscollector.h b/cpp_src/cluster/stats/relicationstatscollector.h index 3fc42f337..58ced4efa 100644 --- a/cpp_src/cluster/stats/relicationstatscollector.h +++ b/cpp_src/cluster/stats/relicationstatscollector.h @@ -1,6 +1,7 @@ #pragma once #include "replicationstats.h" +#include "tools/clock.h" namespace reindexer { namespace cluster { diff --git a/cpp_src/cluster/stats/replicationstats.cc b/cpp_src/cluster/stats/replicationstats.cc index 358af15ac..de1509b25 100644 --- a/cpp_src/cluster/stats/replicationstats.cc +++ b/cpp_src/cluster/stats/replicationstats.cc @@ -1,5 +1,7 @@ #include "replicationstats.h" +#include "cluster/consts.h" #include "core/cjson/jsonbuilder.h" +#include "vendor/gason/gason.h" namespace reindexer { namespace cluster { @@ -135,7 +137,7 @@ void NodeStats::GetJSON(JsonBuilder& builder) const { } } -Error ReplicationStats::FromJSON(span json) { +Error ReplicationStats::FromJSON(std::span json) { try { gason::JsonParser parser; return FromJSON(parser.Parse(json)); @@ -206,6 +208,22 @@ void ReplicationStats::GetJSON(WrSerializer& ser) const { GetJSON(jb); } +void SyncStatsCounter::Hit(std::chrono::microseconds time) noexcept { + std::lock_guard lck(mtx_); + totalTimeUs += time.count(); + ++count; + if (maxTimeUs < time.count()) { + maxTimeUs = time.count(); + } +} + +void SyncStatsCounter::Reset() noexcept { + std::lock_guard lck(mtx_); + count = 0; + maxTimeUs = 0; + totalTimeUs = 0; +} + SyncStats SyncStatsCounter::Get() const { SyncStats stats; { @@ -217,6 +235,23 @@ SyncStats SyncStatsCounter::Get() const { return stats; } +void NodeStatsCounter::SaveLastError(const Error& err) noexcept { + std::lock_guard lck(mtx_); + lastError = err; +} + +Error NodeStatsCounter::GetLastError() const { + std::lock_guard lck(mtx_); + return lastError; +} + +void NodeStatsCounter::Reset() noexcept { + status.store(NodeStats::Status::None, std::memory_order_relaxed); + syncState.store(NodeStats::SyncState::None, std::memory_order_relaxed); + lastAppliedUpdateId_.store(-1, std::memory_order_relaxed); + SaveLastError(Error()); +} + NodeStats NodeStatsCounter::Get() const { NodeStats stats; stats.dsn = dsn; @@ -229,6 +264,53 @@ NodeStats NodeStatsCounter::Get() const { return stats; } +void ReplicationStatCounter::OnStatusChanged(size_t nodeId, NodeStats::Status status) const noexcept { + shared_lock rlck(mtx_); + auto found = nodeCounters_.find(nodeId); + if (found != nodeCounters_.end()) { + found->second->OnStatusChanged(status); + } +} + +void ReplicationStatCounter::OnSyncStateChanged(size_t nodeId, NodeStats::SyncState state) noexcept { + shared_lock rlck(mtx_); + if (nodeId == kLeaderUID && thisNode_.has_value()) { + thisNode_->OnSyncStateChanged(state); + } else { + auto found = nodeCounters_.find(nodeId); + if (found != nodeCounters_.end()) { + found->second->OnSyncStateChanged(state); + } + } +} + +void ReplicationStatCounter::OnServerIdChanged(size_t nodeId, int serverId) const noexcept { + shared_lock rlck(mtx_); + auto found = 
nodeCounters_.find(nodeId); + if (found != nodeCounters_.end()) { + found->second->OnServerIdChanged(serverId); + } +} + +void ReplicationStatCounter::SaveNodeError(size_t nodeId, const Error& lastError) noexcept { + shared_lock rlck(mtx_); + auto found = nodeCounters_.find(nodeId); + if (found != nodeCounters_.end()) { + found->second->SaveLastError(lastError); + } +} + +void ReplicationStatCounter::Reset() noexcept { + walSyncs_.Reset(); + forceSyncs_.Reset(); + initialForceSyncs_.Reset(); + initialWalSyncs_.Reset(); + std::lock_guard lck(mtx_); + for (auto& node : nodeCounters_) { + node.second->Reset(); + } +} + ReplicationStats ReplicationStatCounter::Get() const { ReplicationStats stats; stats.type = type_; diff --git a/cpp_src/cluster/stats/replicationstats.h b/cpp_src/cluster/stats/replicationstats.h index 697bb7a57..5529d2990 100644 --- a/cpp_src/cluster/stats/replicationstats.h +++ b/cpp_src/cluster/stats/replicationstats.h @@ -1,13 +1,11 @@ #pragma once +#include #include #include "cluster/config.h" -#include "cluster/consts.h" -#include "core/perfstatcounter.h" -#include "core/transaction/transaction.h" #include "estl/fast_hash_map.h" +#include "estl/mutex.h" #include "estl/shared_mutex.h" -#include "tools/stringstools.h" namespace reindexer { namespace cluster { @@ -60,7 +58,7 @@ struct NodeStats { }; struct ReplicationStats { - Error FromJSON(span json); + Error FromJSON(std::span json); Error FromJSON(const gason::JsonNode& root); void GetJSON(JsonBuilder& builder) const; void GetJSON(WrSerializer& ser) const; @@ -85,20 +83,8 @@ struct ReplicationStats { }; struct SyncStatsCounter { - void Hit(std::chrono::microseconds time) noexcept { - std::lock_guard lck(mtx_); - totalTimeUs += time.count(); - ++count; - if (maxTimeUs < time.count()) { - maxTimeUs = time.count(); - } - } - void Reset() noexcept { - std::lock_guard lck(mtx_); - count = 0; - maxTimeUs = 0; - totalTimeUs = 0; - } + void Hit(std::chrono::microseconds time) noexcept; + void Reset() noexcept; SyncStats Get() const; size_t count = 0; @@ -113,20 +99,9 @@ struct NodeStatsCounter { void OnStatusChanged(NodeStats::Status st) noexcept { status.store(st, std::memory_order_relaxed); } void OnSyncStateChanged(NodeStats::SyncState st) noexcept { syncState.store(st, std::memory_order_relaxed); } void OnServerIdChanged(int sId) noexcept { serverId.store(sId, std::memory_order_relaxed); } - void SaveLastError(const Error& err) { - std::lock_guard lck(mtx_); - lastError = err; - } - Error GetLastError() const { - std::lock_guard lck(mtx_); - return lastError; - } - void Reset() noexcept { - status.store(NodeStats::Status::None, std::memory_order_relaxed); - syncState.store(NodeStats::SyncState::None, std::memory_order_relaxed); - lastAppliedUpdateId_.store(-1, std::memory_order_relaxed); - SaveLastError(Error()); - } + void SaveLastError(const Error& err) noexcept; + Error GetLastError() const; + void Reset() noexcept; NodeStats Get() const; const DSN dsn; @@ -195,48 +170,11 @@ class ReplicationStatCounter { lastErasedUpdateId_.store(updateId, std::memory_order_relaxed); allocatedUpdatesSizeBytes_.fetch_sub(size, std::memory_order_relaxed); } - void OnStatusChanged(size_t nodeId, NodeStats::Status status) const noexcept { - shared_lock rlck(mtx_); - auto found = nodeCounters_.find(nodeId); - if (found != nodeCounters_.end()) { - found->second->OnStatusChanged(status); - } - } - void OnSyncStateChanged(size_t nodeId, NodeStats::SyncState state) { - shared_lock rlck(mtx_); - if (nodeId == kLeaderUID && thisNode_.has_value()) { 
- thisNode_->OnSyncStateChanged(state); - } else { - auto found = nodeCounters_.find(nodeId); - if (found != nodeCounters_.end()) { - found->second->OnSyncStateChanged(state); - } - } - } - void OnServerIdChanged(size_t nodeId, int serverId) const noexcept { - shared_lock rlck(mtx_); - auto found = nodeCounters_.find(nodeId); - if (found != nodeCounters_.end()) { - found->second->OnServerIdChanged(serverId); - } - } - void SaveNodeError(size_t nodeId, const Error& lastError) { - shared_lock rlck(mtx_); - auto found = nodeCounters_.find(nodeId); - if (found != nodeCounters_.end()) { - found->second->SaveLastError(lastError); - } - } - void Reset() noexcept { - walSyncs_.Reset(); - forceSyncs_.Reset(); - initialForceSyncs_.Reset(); - initialWalSyncs_.Reset(); - std::lock_guard lck(mtx_); - for (auto& node : nodeCounters_) { - node.second->Reset(); - } - } + void OnStatusChanged(size_t nodeId, NodeStats::Status status) const noexcept; + void OnSyncStateChanged(size_t nodeId, NodeStats::SyncState state) noexcept; + void OnServerIdChanged(size_t nodeId, int serverId) const noexcept; + void SaveNodeError(size_t nodeId, const Error& lastError) noexcept; + void Reset() noexcept; ReplicationStats Get() const; private: diff --git a/cpp_src/cmake/modules/FindMKL.cmake b/cpp_src/cmake/modules/FindMKL.cmake new file mode 100644 index 000000000..460b86ad5 --- /dev/null +++ b/cpp_src/cmake/modules/FindMKL.cmake @@ -0,0 +1,363 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from CMake's FindBLAS module. +# Distributed under the OSI-approved BSD 3-Clause License. See accompanying +# file Copyright.txt or https://cmake.org/licensing for details. + +#[=======================================================================[.rst: +FindMKL +-------- + +Find Intel MKL library. + +Input Variables +^^^^^^^^^^^^^^^ + +The following variables may be set to influence this module's behavior: + +``BLA_STATIC`` + if ``ON`` use static linkage + +``BLA_VENDOR`` + If set, checks only the specified vendor, if not set checks all the + possibilities. List of vendors valid in this module: + + * ``Intel10_32`` (intel mkl v10 32 bit) + * ``Intel10_64lp`` (intel mkl v10+ 64 bit, threaded code, lp64 model) + * ``Intel10_64lp_seq`` (intel mkl v10+ 64 bit, sequential code, lp64 model) + * ``Intel10_64ilp`` (intel mkl v10+ 64 bit, threaded code, ilp64 model) + * ``Intel10_64ilp_seq`` (intel mkl v10+ 64 bit, sequential code, ilp64 model) + * ``Intel10_64_dyn`` (intel mkl v10+ 64 bit, single dynamic library) + * ``Intel`` (obsolete versions of mkl 32 and 64 bit) + + +Result Variables +^^^^^^^^^^^^^^^^ + +This module defines the following variables: + +``MKL_FOUND`` + library implementing the BLAS interface is found +``MKL_LIBRARIES`` + uncached list of libraries (using full path name) to link against + to use MKL (may be empty if compiler implicitly links MKL) + +.. note:: + + C or CXX must be enabled to use Intel Math Kernel Library (MKL). + + For example, to use Intel MKL libraries and/or Intel compiler: + + .. code-block:: cmake + + set(BLA_VENDOR Intel10_64lp) + find_package(MKL) + +Hints +^^^^^ + +Set the ``MKLROOT`` environment variable to a directory that contains an MKL +installation, or add the directory to the dynamic library loader environment +variable for your platform (``LIB``, ``DYLD_LIBRARY_PATH`` or +``LD_LIBRARY_PATH``). 
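As a usage illustration (outside the module text itself): once find_package(MKL) resolves MKL_LIBRARIES, a consumer linked against that list can call the standard CBLAS entry points. The sgemm routine below is the same symbol this module's link check probes for; the header name and call shape follow Intel's documented CBLAS interface, and the snippet is a minimal sketch rather than reindexer code.

```cpp
// Minimal consumer of the libraries this module locates (illustrative only).
#include <mkl.h>
#include <cstdio>

int main() {
	const float a[4] = {1, 2, 3, 4};
	const float b[4] = {5, 6, 7, 8};
	float c[4] = {0, 0, 0, 0};
	// C = 1.0 * A * B + 0.0 * C for two 2x2 row-major matrices
	cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 2, 2, 2, 1.0f, a, 2, b, 2, 0.0f, c, 2);
	std::printf("%.1f %.1f\n%.1f %.1f\n", c[0], c[1], c[2], c[3]);
	return 0;
}
```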
+ +#]=======================================================================] + +include(CheckFunctionExists) +include(CMakePushCheckState) +include(FindPackageHandleStandardArgs) +cmake_push_check_state() +set(CMAKE_REQUIRED_QUIET ${BLAS_FIND_QUIETLY}) + + +set(_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES}) +if(BLA_STATIC) + if(WIN32) + set(CMAKE_FIND_LIBRARY_SUFFIXES .lib ${CMAKE_FIND_LIBRARY_SUFFIXES}) + else() + set(CMAKE_FIND_LIBRARY_SUFFIXES .a ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + # for ubuntu's libblas3gf and liblapack3gf packages + set(CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES} .so.3gf) + endif() +endif() + +macro(CHECK_BLAS_LIBRARIES LIBRARIES _prefix _name _flags _list _threadlibs _addlibdir _subdirs) + # This macro checks for the existence of the combination of fortran libraries + # given by _list. If the combination is found, this macro checks (using the + # Check_Fortran_Function_Exists macro) whether can link against that library + # combination using the name of a routine given by _name using the linker + # flags given by _flags. If the combination of libraries is found and passes + # the link test, LIBRARIES is set to the list of complete library paths that + # have been found. Otherwise, LIBRARIES is set to FALSE. + + # N.B. _prefix is the prefix applied to the names of all cached variables that + # are generated internally and marked advanced by this macro. + # _addlibdir is a list of additional search paths. _subdirs is a list of path + # suffixes to be used by find_library(). + + set(_libraries_work TRUE) + set(${LIBRARIES}) + set(_combined_name) + + set(_extaddlibdir "${_addlibdir}") + if(WIN32) + list(APPEND _extaddlibdir ENV LIB) + elseif(APPLE) + list(APPEND _extaddlibdir ENV DYLD_LIBRARY_PATH) + else() + list(APPEND _extaddlibdir ENV LD_LIBRARY_PATH) + endif() + list(APPEND _extaddlibdir "${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}") + + foreach(_library ${_list}) + if(_library MATCHES "^-Wl,--(start|end)-group$") + # Respect linker flags like --start/end-group (required by MKL) + set(${LIBRARIES} ${${LIBRARIES}} "${_library}") + else() + set(_combined_name ${_combined_name}_${_library}) + if(NOT "${_threadlibs}" STREQUAL "") + set(_combined_name ${_combined_name}_threadlibs) + endif() + if(_libraries_work) + find_library(${_prefix}_${_library}_LIBRARY + NAMES ${_library} + PATHS ${_extaddlibdir} + PATH_SUFFIXES ${_subdirs} + ) + #message("DEBUG: find_library(${_library}) got ${${_prefix}_${_library}_LIBRARY}") + mark_as_advanced(${_prefix}_${_library}_LIBRARY) + set(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY}) + set(_libraries_work ${${_prefix}_${_library}_LIBRARY}) + endif() + endif() + endforeach() + + if(_libraries_work) + # Test this combination of libraries. 
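For context on the link test this macro performs: CMake's check_function_exists() compiles a tiny translation unit that merely declares and references the probed symbol (here sgemm with a Fortran-style underscore suffix, as used further down in the macro) and tries to link it against CMAKE_REQUIRED_LIBRARIES. A rough C++ equivalent of that generated probe, shown only as an illustration:

```cpp
// Approximation of the probe built by check_function_exists("sgemm_", ...);
// only symbol resolution matters, the real BLAS signature is never used here.
extern "C" char sgemm_();

int main() {
	return sgemm_() ? 0 : 1;  // links only if some candidate library exports sgemm_
}
```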
+ set(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}} ${_threadlibs}) + #message("DEBUG: CMAKE_REQUIRED_LIBRARIES = ${CMAKE_REQUIRED_LIBRARIES}") + if(CMAKE_Fortran_COMPILER_LOADED) + check_fortran_function_exists("${_name}" ${_prefix}${_combined_name}_WORKS) + else() + check_function_exists("${_name}_" ${_prefix}${_combined_name}_WORKS) + endif() + set(CMAKE_REQUIRED_LIBRARIES) + set(_libraries_work ${${_prefix}${_combined_name}_WORKS}) + endif() + + if(_libraries_work) + if("${_list}" STREQUAL "") + set(${LIBRARIES} "${LIBRARIES}-PLACEHOLDER-FOR-EMPTY-LIBRARIES") + else() + set(${LIBRARIES} ${${LIBRARIES}} ${_threadlibs}) + endif() + else() + set(${LIBRARIES} FALSE) + endif() + #message("DEBUG: ${LIBRARIES} = ${${LIBRARIES}}") +endmacro() + +set(MKL_LIBRARIES) +if(NOT $ENV{BLA_VENDOR} STREQUAL "") + set(BLA_VENDOR $ENV{BLA_VENDOR}) +else() + if(NOT BLA_VENDOR) + set(BLA_VENDOR "All") + endif() +endif() +if(BLA_VENDOR_THREADING) + set(BLAS_mkl_THREADING ${BLA_VENDOR_THREADING}) +else() + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(BLAS_mkl_THREADING "gnu") + else() + set(BLAS_mkl_THREADING "intel") + endif() +endif() + +if(CMAKE_C_COMPILER_LOADED OR CMAKE_CXX_COMPILER_LOADED) + # System-specific settings + if(WIN32) + if(BLA_STATIC) + set(BLAS_mkl_DLL_SUFFIX "") + else() + set(BLAS_mkl_DLL_SUFFIX "_dll") + endif() + else() + if(BLA_STATIC) + set(BLAS_mkl_START_GROUP "-Wl,--start-group") + set(BLAS_mkl_END_GROUP "-Wl,--end-group") + else() + set(BLAS_mkl_START_GROUP "") + set(BLAS_mkl_END_GROUP "") + endif() + if(BLAS_mkl_THREADING STREQUAL "gnu") + set(BLAS_mkl_OMP "gomp") + else() + set(BLAS_mkl_OMP "iomp5") + endif() + set(BLAS_mkl_LM "-lm") + set(BLAS_mkl_LDL "-ldl") + endif() + + if(BLAS_FIND_QUIETLY OR NOT BLAS_FIND_REQUIRED) + find_package(Threads) + else() + find_package(Threads REQUIRED) + endif() + + set(BLAS_mkl_INTFACE "intel") + if(BLA_VENDOR MATCHES "_64ilp") + set(BLAS_mkl_ILP_MODE "ilp64") + else() + set(BLAS_mkl_ILP_MODE "lp64") + endif() + + set(BLAS_SEARCH_LIBS "") + + set(BLAS_mkl_SEARCH_SYMBOL sgemm) + set(_LIBRARIES MKL_LIBRARIES) + if(WIN32) + # Find the main file (32-bit or 64-bit) + set(BLAS_SEARCH_LIBS_WIN_MAIN "") + if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_intel_c${BLAS_mkl_DLL_SUFFIX}") + endif() + if(BLA_VENDOR MATCHES "^Intel10_64i?lp" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_MAIN + "mkl_intel_${BLAS_mkl_ILP_MODE}${BLAS_mkl_DLL_SUFFIX}") + endif() + + # Add threading/sequential libs + set(BLAS_SEARCH_LIBS_WIN_THREAD "") + if(BLA_VENDOR MATCHES "^Intel10_64i?lp$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libguide40 mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "libiomp5md mkl_intel_thread${BLAS_mkl_DLL_SUFFIX}") + endif() + if(BLA_VENDOR MATCHES "^Intel10_64i?lp_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS_WIN_THREAD + "mkl_sequential${BLAS_mkl_DLL_SUFFIX}") + endif() + + # Cartesian product of the above + foreach(MAIN ${BLAS_SEARCH_LIBS_WIN_MAIN}) + foreach(THREAD ${BLAS_SEARCH_LIBS_WIN_THREAD}) + list(APPEND BLAS_SEARCH_LIBS + "${MAIN} ${THREAD} mkl_core${BLAS_mkl_DLL_SUFFIX}") + endforeach() + endforeach() + else() + if(BLA_VENDOR STREQUAL "Intel10_32" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND 
BLAS_SEARCH_LIBS + "${BLAS_mkl_START_GROUP} mkl_${BLAS_mkl_INTFACE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_END_GROUP} ${BLAS_mkl_OMP}") + endif() + if(BLA_VENDOR MATCHES "^Intel10_64i?lp$" OR BLA_VENDOR STREQUAL "All") + # old version + list(APPEND BLAS_SEARCH_LIBS + "mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core guide") + + # mkl >= 10.3 + list(APPEND BLAS_SEARCH_LIBS + "${BLAS_mkl_START_GROUP} mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_${BLAS_mkl_THREADING}_thread mkl_core ${BLAS_mkl_END_GROUP} ${BLAS_mkl_OMP}") + endif() + if(BLA_VENDOR MATCHES "^Intel10_64i?lp_seq$" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS + "${BLAS_mkl_START_GROUP} mkl_${BLAS_mkl_INTFACE}_${BLAS_mkl_ILP_MODE} mkl_sequential mkl_core ${BLAS_mkl_END_GROUP}") + endif() + + #older vesions of intel mkl libs + if(BLA_VENDOR STREQUAL "Intel" OR BLA_VENDOR STREQUAL "All") + list(APPEND BLAS_SEARCH_LIBS + "mkl") + list(APPEND BLAS_SEARCH_LIBS + "mkl_ia32") + list(APPEND BLAS_SEARCH_LIBS + "mkl_em64t") + endif() + endif() + + if(BLA_VENDOR MATCHES "^Intel10_64_dyn$" OR BLA_VENDOR STREQUAL "All") + # mkl >= 10.3 with single dynamic library + list(APPEND BLAS_SEARCH_LIBS + "mkl_rt") + endif() + + # MKL uses a multitude of partially platform-specific subdirectories: + if(BLA_VENDOR STREQUAL "Intel10_32") + set(BLAS_mkl_ARCH_NAME "ia32") + else() + set(BLAS_mkl_ARCH_NAME "intel64") + endif() + if(WIN32) + set(BLAS_mkl_OS_NAME "win") + elseif(APPLE) + set(BLAS_mkl_OS_NAME "mac") + else() + set(BLAS_mkl_OS_NAME "lin") + endif() + if(DEFINED ENV{MKLROOT}) + file(TO_CMAKE_PATH "$ENV{MKLROOT}" BLAS_mkl_MKLROOT) + # If MKLROOT points to the subdirectory 'mkl', use the parent directory instead + # so we can better detect other relevant libraries in 'compiler' or 'tbb': + get_filename_component(BLAS_mkl_MKLROOT_LAST_DIR "${BLAS_mkl_MKLROOT}" NAME) + if(BLAS_mkl_MKLROOT_LAST_DIR STREQUAL "mkl") + get_filename_component(BLAS_mkl_MKLROOT "${BLAS_mkl_MKLROOT}" DIRECTORY) + endif() + endif() + set(BLAS_mkl_LIB_PATH_SUFFIXES + "compiler/lib" "compiler/lib/${BLAS_mkl_ARCH_NAME}_${BLAS_mkl_OS_NAME}" + "mkl/lib" "mkl/lib/${BLAS_mkl_ARCH_NAME}_${BLAS_mkl_OS_NAME}" + "lib/${BLAS_mkl_ARCH_NAME}_${BLAS_mkl_OS_NAME}") + + foreach(IT ${BLAS_SEARCH_LIBS}) + string(REPLACE " " ";" SEARCH_LIBS ${IT}) + if(NOT ${_LIBRARIES}) + check_blas_libraries( + ${_LIBRARIES} + BLAS + ${BLAS_mkl_SEARCH_SYMBOL} + "" + "${SEARCH_LIBS}" + "${CMAKE_THREAD_LIBS_INIT};${BLAS_mkl_LM};${BLAS_mkl_LDL}" + "${BLAS_mkl_MKLROOT}" + "${BLAS_mkl_LIB_PATH_SUFFIXES}" + ) + endif() + endforeach() + + unset(BLAS_mkl_ILP_MODE) + unset(BLAS_mkl_INTFACE) + unset(BLAS_mkl_THREADING) + unset(BLAS_mkl_OMP) + unset(BLAS_mkl_DLL_SUFFIX) + unset(BLAS_mkl_LM) + unset(BLAS_mkl_LDL) + unset(BLAS_mkl_MKLROOT) + unset(BLAS_mkl_MKLROOT_LAST_DIR) + unset(BLAS_mkl_ARCH_NAME) + unset(BLAS_mkl_OS_NAME) + unset(BLAS_mkl_LIB_PATH_SUFFIXES) +endif() + + +find_package_handle_standard_args(MKL REQUIRED_VARS MKL_LIBRARIES) + +cmake_pop_check_state() +set(CMAKE_FIND_LIBRARY_SUFFIXES ${_blas_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES}) diff --git a/cpp_src/cmake/modules/RxPrepareCpackDeps.cmake b/cpp_src/cmake/modules/RxPrepareCpackDeps.cmake index 59118068c..4823f5d05 100644 --- a/cpp_src/cmake/modules/RxPrepareCpackDeps.cmake +++ b/cpp_src/cmake/modules/RxPrepareCpackDeps.cmake @@ -24,7 +24,7 @@ endif() message("Target cpack package type was detected as '${RxPrepareCpackDeps}'") -SET(CPACK_PACKAGE_NAME "reindexer-4") +SET(CPACK_PACKAGE_NAME 
"reindexer") SET(CPACK_PACKAGE_DESCRIPTION_SUMMARY "ReindexerDB server package") SET(CPACK_PACKAGE_VENDOR "Reindexer") SET(CPACK_PACKAGE_CONTACT "Reindexer team ") @@ -97,6 +97,15 @@ else() set (CPACK_DEBIAN_DEV_PACKAGE_DEPENDS "libleveldb-dev,${CPACK_DEBIAN_PACKAGE_DEPENDS}") endif() +SET(CPACK_DEBIAN_PACKAGE_DEPENDS "${CPACK_DEBIAN_PACKAGE_DEPENDS},libopenblas-openmp-dev") +if (LINUX_ISSUE MATCHES "altlinux") + SET(CPACK_RPM_PACKAGE_REQUIRES_PRE "${CPACK_RPM_PACKAGE_REQUIRES_PRE},libgomp") + SET(CPACK_RPM_PACKAGE_REQUIRES_PRE "${CPACK_RPM_PACKAGE_REQUIRES_PRE},liblapack") + SET(CPACK_RPM_PACKAGE_REQUIRES_PRE "${CPACK_RPM_PACKAGE_REQUIRES_PRE},libopenblas") +else() + SET(CPACK_RPM_PACKAGE_REQUIRES_PRE "${CPACK_RPM_PACKAGE_REQUIRES_PRE},openblas") +endif() + if (CPACK_RPM_PACKAGE_REQUIRES_PRE STREQUAL "") set (CPACK_RPM_DEV_PACKAGE_REQUIRES_PRE "${RPM_EXTRA_LIB_PREFIX}leveldb-devel") else() diff --git a/cpp_src/cmake/modules/RxPrepareInstallFiles.cmake b/cpp_src/cmake/modules/RxPrepareInstallFiles.cmake index ff0fded50..cbe36a050 100644 --- a/cpp_src/cmake/modules/RxPrepareInstallFiles.cmake +++ b/cpp_src/cmake/modules/RxPrepareInstallFiles.cmake @@ -24,13 +24,17 @@ endif() if (${lib} MATCHES "jemalloc" OR ${lib} MATCHES "tcmalloc") elseif(${lib} STREQUAL "-pthread") list(APPEND flibs " -lpthread") - elseif("${lib}" MATCHES "^\\-.*") + elseif(${lib} MATCHES "^\\-.*") list(APPEND flibs " ${lib}") else() if (NOT "${lib}" STREQUAL "snappy" OR SNAPPY_FOUND) - get_filename_component(lib ${lib} NAME_WE) - string(REGEX REPLACE "^lib" "" lib ${lib}) - list(APPEND flibs " -l${lib}") + get_filename_component(lib_name ${lib} NAME_WE) + string(REGEX REPLACE "^lib" "" lib_name ${lib_name}) + if (${lib} MATCHES "framework") + list(APPEND flibs " -framework ${lib_name}") + else() + list(APPEND flibs " -l${lib_name}") + endif() else() list(APPEND flibs " -l${lib}") endif() @@ -67,19 +71,21 @@ if (NOT WIN32) "core/namespacedef.h" "core/keyvalue/variant.h" "core/keyvalue/geometry.h" "core/sortingprioritiestable.h" "core/rdxcontext.h" "core/activity_context.h" "core/activity.h" "core/activitylog.h" "core/type_consts_helpers.h" "core/payload/fieldsset.h" "core/payload/payloadtype.h" "core/cbinding/reindexer_c.h" "core/cbinding/reindexer_ctypes.h" "core/transaction/transaction.h" "core/payload/payloadfieldtype.h" "core/reindexerconfig.h" - "core/query/query.h" "core/query/queryentry.h" "core/queryresults/queryresults.h" "core/indexdef.h" "core/queryresults/aggregationresult.h" + "core/query/query.h" "core/query/queryentry.h" "core/queryresults/queryresults.h" "core/query/knn_search_params.h" "core/indexdef.h" "core/queryresults/aggregationresult.h" "core/queryresults/itemref.h" "core/namespace/stringsholder.h" "core/keyvalue/key_string.h" "core/keyvalue/uuid.h" "core/key_value_type.h" - "core/namespace/incarnationtags.h" "core/keyvalue/p_string.h" - "core/itemimplrawdata.h" "core/expressiontree.h" "tools/lsn.h" "core/cjson/tagspath.h" "core/cjson/ctag.h" + "core/namespace/incarnationtags.h" "core/keyvalue/p_string.h" "core/keyvalue/float_vector.h" "core/enums.h" "core/keyvalue/float_vectors_holder.h" "core/namespace/float_vectors_indexes.h" + "core/itemimplrawdata.h" "core/expressiontree.h" "tools/lsn.h" "core/cjson/tagspath.h" "core/cjson/ctag.h" "core/rank_t.h" "core/system_ns_names.h" "estl/cow.h" "core/shardedmeta.h" "estl/overloaded.h" "estl/one_of.h" "core/queryresults/localqueryresults.h" - "estl/h_vector.h" "estl/mutex.h" "estl/intrusive_ptr.h" "estl/trivial_reverse_iterator.h" "estl/span.h" 
"estl/chunk.h" + "estl/h_vector.h" "estl/mutex.h" "estl/intrusive_ptr.h" "estl/trivial_reverse_iterator.h" "estl/chunk.h" "estl/expected.h" + "estl/fast_hash_map.h" "vendor/hopscotch/hopscotch_map.h" "vendor/hopscotch/hopscotch_sc_map.h" "vendor/hopscotch/hopscotch_hash.h" "estl/elist.h" "estl/fast_hash_traits.h" "estl/debug_macros.h" "estl/defines.h" "estl/template.h" "estl/comparation_result.h" "client/item.h" "client/resultserializer.h" "client/internalrdxcontext.h" "client/reindexer.h" "client/reindexerconfig.h" "client/cororeindexer.h" "client/coroqueryresults.h" "client/corotransaction.h" "client/connectopts.h" "client/queryresults.h" "client/transaction.h" "net/ev/ev.h" "vendor/koishi/include/koishi.h" "coroutine/coroutine.h" "coroutine/channel.h" "coroutine/waitgroup.h" + "vendor/expected/expected.h" "debug/backtrace.h" "debug/allocdebug.h" "debug/resolver.h" "vendor/gason/gason.h" ) diff --git a/cpp_src/cmd/reindexer_server/CMakeLists.txt b/cpp_src/cmd/reindexer_server/CMakeLists.txt index cebb83d18..8149976b2 100644 --- a/cpp_src/cmd/reindexer_server/CMakeLists.txt +++ b/cpp_src/cmd/reindexer_server/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.10) +cmake_minimum_required(VERSION 3.18) project(reindexer_server) diff --git a/cpp_src/cmd/reindexer_server/contrib/Dockerfile b/cpp_src/cmd/reindexer_server/contrib/Dockerfile index d60226e67..b3b6fe89e 100644 --- a/cpp_src/cmd/reindexer_server/contrib/Dockerfile +++ b/cpp_src/cmd/reindexer_server/contrib/Dockerfile @@ -1,18 +1,24 @@ FROM alpine:3.20 AS build RUN cd /tmp && apk update && \ - apk add git curl autoconf automake libtool linux-headers g++ make libunwind-dev grpc-dev protobuf-dev c-ares-dev patch openssl && \ + apk add git curl autoconf automake libtool linux-headers g++ make libunwind-dev grpc-dev protobuf-dev c-ares-dev patch && \ git clone https://github.com/gperftools/gperftools.git && \ cd gperftools && git checkout gperftools-2.15 && \ sed -i s/_sigev_un\._tid/sigev_notify_thread_id/ src/profile-handler.cc && \ ./autogen.sh && ./configure --disable-dependency-tracking && make -j6 && make install +RUN git clone https://github.com/OpenMathLib/OpenBLAS.git && cd OpenBLAS && \ + make TARGET=NEHALEM USE_THREAD=1 NUM_THREADS=16 NO_STATIC=1 NO_WARMUP=1 COMMON_OPT=-O2 DYNAMIC_ARCH=1 USE_OPENMP=0 && \ + make PREFIX=/usr/local install TARGET=NEHALEM USE_THREAD=1 NUM_THREADS=16 NO_STATIC=1 NO_WARMUP=1 COMMON_OPT=-O2 DYNAMIC_ARCH=1 USE_OPENMP=0 + ADD . /src WORKDIR /src -RUN ./dependencies.sh && \ - mkdir build && \ +# Install reindexer dependecies, except lapack/openblas +RUN apk add snappy-dev leveldb-dev libunwind-dev make curl cmake unzip git openssl-dev + +RUN mkdir build && \ cd build && \ cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo .. 
-DENABLE_GRPC=On -DGRPC_PACKAGE_PROVIDER="" && \ make -j6 reindexer_server reindexer_tool && \ @@ -27,7 +33,7 @@ FROM alpine:3.20 COPY --from=build /usr/local /usr/local COPY --from=build /entrypoint.sh /entrypoint.sh -RUN apk update && apk add libstdc++ libunwind snappy leveldb c-ares libprotobuf xz-libs grpc-cpp && rm -rf /var/cache/apk/* +RUN apk update && apk add libstdc++ libunwind snappy leveldb c-ares libgomp libprotobuf xz-libs grpc-cpp && rm -rf /var/cache/apk/* RUN ln -s /usr/lib/libcrypto.so.3 /usr/lib/libcrypto.so && \ ln -s /usr/lib/libssl.so.3 /usr/lib/libssl.so diff --git a/cpp_src/cmd/reindexer_server/contrib/Dockerfile.deb b/cpp_src/cmd/reindexer_server/contrib/Dockerfile.deb index 2965c7001..cf468c766 100644 --- a/cpp_src/cmd/reindexer_server/contrib/Dockerfile.deb +++ b/cpp_src/cmd/reindexer_server/contrib/Dockerfile.deb @@ -1,7 +1,7 @@ FROM debian:stable-slim AS build RUN apt update -y && apt install -y libunwind-dev build-essential libsnappy-dev libleveldb-dev openssl \ - make curl unzip git cmake libjemalloc-dev \ + make curl unzip git cmake libjemalloc-dev libopenblas-pthread-dev \ libgrpc++-dev protobuf-compiler-grpc protobuf-compiler libprotobuf-dev ADD . /src @@ -19,7 +19,7 @@ RUN cd /src && \ FROM debian:stable-slim COPY --from=build /usr/local /usr/local COPY --from=build /entrypoint.sh /entrypoint.sh -RUN apt update -y && apt install -y libleveldb1d libunwind8 libjemalloc2 libgrpc++1.51 && rm -rf /var/lib/apt +RUN apt update -y && apt install -y libleveldb1d libunwind8 libopenblas0-openmp libjemalloc2 libgrpc++1.51 && rm -rf /var/lib/apt RUN ln -s /usr/lib/x86_64-linux-gnu/libcrypto.so.3 /usr/lib/x86_64-linux-gnu/libcrypto.so && \ ln -s /usr/lib/x86_64-linux-gnu/libssl.so.3 /usr/lib/x86_64-linux-gnu/libssl.so @@ -30,6 +30,8 @@ ENV RX_HTTPLOG=stdout ENV RX_RPCLOG=stdout ENV RX_SERVERLOG=stdout ENV RX_LOGLEVEL=info +# Number of thread to build IVF's centroids +ENV RX_IVF_OMP_THREADS=8 RUN chmod +x /entrypoint.sh diff --git a/cpp_src/cmd/reindexer_server/test/check_rx_version.sh b/cpp_src/cmd/reindexer_server/test/check_rx_version.sh index 3eec3cb06..127aecb85 100755 --- a/cpp_src/cmd/reindexer_server/test/check_rx_version.sh +++ b/cpp_src/cmd/reindexer_server/test/check_rx_version.sh @@ -3,22 +3,22 @@ PACKAGE=$1 if [ "$PACKAGE" == "deb" ]; then - RX_SERVER_REQUIRED_VERSION="$(basename build/reindexer-4-server*.deb .deb)" + RX_SERVER_REQUIRED_VERSION="$(basename build/reindexer-server*.deb .deb)" RX_SERVER_REQUIRED_VERSION=$(echo "$RX_SERVER_REQUIRED_VERSION" | cut -d'_' -f 2) - RX_SERVER_INSTALLED_VERSION="$(dpkg -s reindexer-4-server | grep Version)" + RX_SERVER_INSTALLED_VERSION="$(dpkg -s reindexer-server | grep Version)" RX_SERVER_INSTALLED_VERSION="${RX_SERVER_INSTALLED_VERSION#*: }" elif [ "$PACKAGE" == "rpm" ]; then - RX_SERVER_REQUIRED_VERSION="$(basename build/reindexer-4-server*.rpm .rpm)" + RX_SERVER_REQUIRED_VERSION="$(basename build/reindexer-server*.rpm .rpm)" OS=$(echo ${ID} | tr '[:upper:]' '[:lower:]') if [ "$OS" = "redos" ]; then - RX_SERVER_INSTALLED_VERSION="$(dnf list installed \"reindexer-4-server\" | tail -n 1 | awk \'{print $$2}\')" + RX_SERVER_INSTALLED_VERSION="$(dnf list installed \"reindexer-server\" | tail -n 1 | awk \'{print $$2}\')" echo RX_SERVER_INSTALLED_VERSION=$RX_SERVER_INSTALLED_VERSION echo "Installed!!!" - dnf list installed \"reindexer-4-server\" + dnf list installed \"reindexer-server\" echo "More!!!" 
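Returning to the RX_IVF_OMP_THREADS variable added in the Dockerfile.deb hunk above: per its comment, it sets the number of OpenMP threads used to build IVF centroids. The sketch below shows one plausible way such a limit could be read and applied; the function name, default, and wiring are illustrative assumptions, not reindexer's actual implementation.

```cpp
// Illustrative only: applying an RX_IVF_OMP_THREADS-style cap to an
// OpenMP-parallel IVF centroid build.
#include <algorithm>
#include <cstdlib>
#include <string>
#include <omp.h>

static int IvfBuildThreads() {
	int threads = 8;  // mirrors the default exported in the Dockerfile
	if (const char* env = std::getenv("RX_IVF_OMP_THREADS")) {
		try {
			threads = std::max(1, std::stoi(env));
		} catch (...) {
			// malformed value: keep the default
		}
	}
	return threads;
}

void BuildIvfCentroids(/* training vectors */) {
	omp_set_num_threads(IvfBuildThreads());
	// ... k-means iterations producing the IVF centroids would run here ...
}
```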
- dnf list installed \"reindexer-4-server\" | tail -n 1 + dnf list installed \"reindexer-server\" | tail -n 1 else - RX_SERVER_INSTALLED_VERSION="$(rpm -q reindexer-4-server)" + RX_SERVER_INSTALLED_VERSION="$(rpm -q reindexer-server)" fi else echo "Unknown package extension" diff --git a/cpp_src/cmd/reindexer_server/test/test_storage_compatibility.sh b/cpp_src/cmd/reindexer_server/test/test_storage_compatibility.sh new file mode 100755 index 000000000..47d43d343 --- /dev/null +++ b/cpp_src/cmd/reindexer_server/test/test_storage_compatibility.sh @@ -0,0 +1,198 @@ +#!/bin/bash +# Task: https://github.com/restream/reindexer/-/issues/1188 +set -e + +function KillAndRemoveServer { + local pid=$1 + kill $pid + wait $pid + yum remove -y 'reindexer*' > /dev/null +} + +function WaitForDB { + # wait until DB is loaded + set +e # disable "exit on error" so the script won't stop when DB's not loaded yet + is_connected=$(reindexer_tool --dsn $ADDRESS --command '\databases list'); + while [[ $is_connected != "test" ]] + do + sleep 2 + is_connected=$(reindexer_tool --dsn $ADDRESS --command '\databases list'); + done + set -e +} + +function CompareNamespacesLists { + local ns_list_actual=$1 + local ns_list_expected=$2 + local pid=$3 + + diff=$(echo ${ns_list_actual[@]} ${ns_list_expected[@]} | tr ' ' '\n' | sort | uniq -u) # compare in any order + if [ "$diff" == "" ]; then + echo "## PASS: namespaces list not changed" + else + echo "##### FAIL: namespaces list was changed" + echo "expected: $ns_list_expected" + echo "actual: $ns_list_actual" + KillAndRemoveServer $pid; + exit 1 + fi +} + +function CompareMemstats { + local actual=$1 + local expected=$2 + local pid=$3 + diff=$(echo ${actual[@]} ${expected[@]} | tr ' ' '\n' | sed 's/\(.*\),$/\1/' | sort | uniq -u) # compare in any order + if [ "$diff" == "" ]; then + echo "## PASS: memstats not changed" + else + echo "##### FAIL: memstats was changed" + echo "expected: $expected" + echo "actual: $actual" + KillAndRemoveServer $pid; + exit 1 + fi +} + + +RX_SERVER_CURRENT_VERSION_RPM="$(basename build/reindexer-*server*.rpm)" +VERSION_FROM_RPM=$(echo "$RX_SERVER_CURRENT_VERSION_RPM" | grep -o '.*server-..') +VERSION=$(echo ${VERSION_FROM_RPM: -2:1}) # one-digit version + +echo "## choose latest release rpm file" +if [ $VERSION == 3 ]; then + LATEST_RELEASE=$(python3 cpp_src/cmd/reindexer_server/test/get_last_rx_version.py -v 3) + namespaces_list_expected=$'purchase_options_ext_dict\nchild_account_recommendations\n#config\n#activitystats\nradio_channels\ncollections\n#namespaces\nwp_imports_tasks\nepg_genres\nrecom_media_items_personal\nrecom_epg_archive_default\n#perfstats\nrecom_epg_live_default\nmedia_view_templates\nasset_video_servers\nwp_tasks_schedule\nadmin_roles\n#clientsstats\nrecom_epg_archive_personal\nrecom_media_items_similars\nmenu_items\naccount_recommendations\nkaraoke_items\nmedia_items\nbanners\n#queriesperfstats\nrecom_media_items_default\nrecom_epg_live_personal\nservices\n#memstats\nchannels\nmedia_item_recommendations\nwp_tasks_tasks\nepg' +elif [ $VERSION == 4 -o $VERSION == 5 ]; then + # TODO: V5 should use basic reindexer name without '-4'/'-5' infix + LATEST_RELEASE=$(python3 cpp_src/cmd/reindexer_server/test/get_last_rx_version.py -v 4) + # replicationstats ns added for v4 + 
namespaces_list_expected=$'purchase_options_ext_dict\nchild_account_recommendations\n#config\n#activitystats\n#replicationstats\nradio_channels\ncollections\n#namespaces\nwp_imports_tasks\nepg_genres\nrecom_media_items_personal\nrecom_epg_archive_default\n#perfstats\nrecom_epg_live_default\nmedia_view_templates\nasset_video_servers\nwp_tasks_schedule\nadmin_roles\n#clientsstats\nrecom_epg_archive_personal\nrecom_media_items_similars\nmenu_items\naccount_recommendations\nkaraoke_items\nmedia_items\nbanners\n#queriesperfstats\nrecom_media_items_default\nrecom_epg_live_personal\nservices\n#memstats\nchannels\nmedia_item_recommendations\nwp_tasks_tasks\nepg' +else + echo "Unknown version" + exit 1 +fi + +echo "## downloading latest release rpm file: $LATEST_RELEASE" +curl "http://repo.itv.restr.im/itv-api-ng/7/x86_64/$LATEST_RELEASE" --output $LATEST_RELEASE; +echo "## downloading example DB" +curl "https://github.com/restream/reindexer_testdata/-/raw/main/dump_demo.zip" --output dump_demo.zip; +unzip -o dump_demo.zip # unzips into demo_test.rxdump; + +ADDRESS="cproto://127.0.0.1:6534/" +DB_NAME="test" + +memstats_expected=$'[ +{"name":"account_recommendations","replication":{"data_hash":6833710705,"data_count":1}}, +{"name":"admin_roles","replication":{"data_hash":1896088071,"data_count":2}}, +{"name":"asset_video_servers","replication":{"data_hash":7404222244,"data_count":97}}, +{"name":"banners","replication":{"data_hash":0,"data_count":0}}, +{"name":"channels","replication":{"data_hash":457292509431319,"data_count":3941}}, +{"name":"child_account_recommendations","replication":{"data_hash":6252344969,"data_count":1}}, +{"name":"collections","replication":{"data_hash":0,"data_count":0}}, +{"name":"epg","replication":{"data_hash":-7049751653258,"data_count":1623116}}, +{"name":"epg_genres","replication":{"data_hash":8373644068,"data_count":1315}}, +{"name":"karaoke_items","replication":{"data_hash":5858155773472,"data_count":4500}}, +{"name":"media_item_recommendations","replication":{"data_hash":-6520334670,"data_count":35886}}, +{"name":"media_items","replication":{"data_hash":-1824301168479972392,"data_count":65448}}, +{"name":"media_view_templates","replication":{"data_hash":0,"data_count":0}}, +{"name":"menu_items","replication":{"data_hash":0,"data_count":0}}, +{"name":"purchase_options_ext_dict","replication":{"data_hash":24651210926,"data_count":3}}, +{"name":"radio_channels","replication":{"data_hash":37734732881,"data_count":28}}, +{"name":"recom_epg_archive_default","replication":{"data_hash":0,"data_count":0}}, +{"name":"recom_epg_archive_personal","replication":{"data_hash":0,"data_count":0}}, +{"name":"recom_epg_live_default","replication":{"data_hash":0,"data_count":0}}, +{"name":"recom_epg_live_personal","replication":{"data_hash":0,"data_count":0}}, +{"name":"recom_media_items_default","replication":{"data_hash":8288213744,"data_count":3}}, +{"name":"recom_media_items_personal","replication":{"data_hash":0,"data_count":0}}, +{"name":"recom_media_items_similars","replication":{"data_hash":-672103903,"data_count":33538}}, +{"name":"services","replication":{"data_hash":0,"data_count":0}}, +{"name":"wp_imports_tasks","replication":{"data_hash":777859741066,"data_count":1145}}, +{"name":"wp_tasks_schedule","replication":{"data_hash":12595790956,"data_count":4}}, +{"name":"wp_tasks_tasks","replication":{"data_hash":28692716680,"data_count":281}} +] +Returned 27 rows' + +echo "##### Forward compatibility test #####" + +DB_PATH=$(pwd)"/rx_db" + +echo "Database: "$DB_PATH + 
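The expected namespace list and memstats above are later compared against live output that the script obtains through reindexer_tool (select name, replication.data_hash, replication.data_count from #memstats order by name). For readers more familiar with the embedded C++ API, a rough equivalent of that query using the Query/Select pattern seen elsewhere in this diff is sketched below; the helper name is invented, includes and error handling are abridged, and the tool command additionally restricts the selected fields.

```cpp
// Sketch only: dump #memstats as JSON, one object per namespace.
#include "core/query/query.h"
#include "core/reindexer.h"
#include "tools/serializer.h"

reindexer::Error DumpMemstats(reindexer::Reindexer& db, reindexer::WrSerializer& out) {
	reindexer::QueryResults qr;
	auto err = db.Select(reindexer::Query("#memstats").Sort("name", false), qr);
	if (!err.ok()) {
		return err;
	}
	for (auto it : qr) {
		err = it.GetJSON(out, false);  // serialize the item without a length header
		if (!err.ok()) {
			return err;
		}
		out << '\n';
	}
	return reindexer::Error();
}
```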
+echo "## installing latest release: $LATEST_RELEASE" +yum install -y $LATEST_RELEASE > /dev/null; +# run RX server with disabled logging +reindexer_server -l warning --httplog=none --rpclog=none --db $DB_PATH & +server_pid=$! +sleep 2; + +reindexer_tool --dsn $ADDRESS$DB_NAME -f demo_test.rxdump --createdb; +sleep 1; + +namespaces_1=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command '\namespaces list'); +echo $namespaces_1; +CompareNamespacesLists "${namespaces_1[@]}" "${namespaces_list_expected[@]}" $server_pid; + +memstats_1=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command 'select name, replication.data_hash, replication.data_count from #memstats order by name'); +CompareMemstats "${memstats_1[@]}" "${memstats_expected[@]}" $server_pid; + +KillAndRemoveServer $server_pid; + +echo "## installing current version: $RX_SERVER_CURRENT_VERSION_RPM" +yum install -y build/*.rpm > /dev/null; +reindexer_server -l0 --corelog=none --httplog=none --rpclog=none --db $DB_PATH & +server_pid=$! +sleep 2; + +WaitForDB + +namespaces_2=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command '\namespaces list'); +echo $namespaces_2; +CompareNamespacesLists "${namespaces_2[@]}" "${namespaces_1[@]}" $server_pid; + + +memstats_2=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command 'select name, replication.data_hash, replication.data_count from #memstats order by name'); +CompareMemstats "${memstats_2[@]}" "${memstats_1[@]}" $server_pid; + +KillAndRemoveServer $server_pid; +rm -rf $DB_PATH; +sleep 1; + +echo "##### Backward compatibility test #####" + +echo "## installing current version: $RX_SERVER_CURRENT_VERSION_RPM" +yum install -y build/*.rpm > /dev/null; +reindexer_server -l warning --httplog=none --rpclog=none --db $DB_PATH & +server_pid=$! +sleep 2; + +reindexer_tool --dsn $ADDRESS$DB_NAME -f demo_test.rxdump --createdb; +sleep 1; + +namespaces_3=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command '\namespaces list'); +echo $namespaces_3; +CompareNamespacesLists "${namespaces_3[@]}" "${namespaces_list_expected[@]}" $server_pid; + + +memstats_3=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command 'select name, replication.data_hash, replication.data_count from #memstats order by name'); +CompareMemstats "${memstats_3[@]}" "${memstats_expected[@]}" $server_pid; + +KillAndRemoveServer $server_pid; + +echo "## installing latest release: $LATEST_RELEASE" +yum install -y $LATEST_RELEASE > /dev/null; +reindexer_server -l warning --httplog=none --rpclog=none --db $DB_PATH & +server_pid=$! 
+sleep 2; + +WaitForDB + +namespaces_4=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command '\namespaces list'); +echo $namespaces_4; +CompareNamespacesLists "${namespaces_4[@]}" "${namespaces_3[@]}" $server_pid; + +memstats_4=$(reindexer_tool --dsn $ADDRESS$DB_NAME --command 'select name, replication.data_hash, replication.data_count from #memstats order by name'); +CompareMemstats "${memstats_4[@]}" "${memstats_3[@]}" $server_pid; + +KillAndRemoveServer $server_pid; +rm -rf $DB_PATH; diff --git a/cpp_src/cmd/reindexer_tool/CMakeLists.txt b/cpp_src/cmd/reindexer_tool/CMakeLists.txt index f13bcbbc1..9bc1adcbf 100644 --- a/cpp_src/cmd/reindexer_tool/CMakeLists.txt +++ b/cpp_src/cmd/reindexer_tool/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.10) +cmake_minimum_required(VERSION 3.18) project(reindexer_tool) @@ -42,6 +42,7 @@ add_executable(${TARGET} ${SRCS}) # Enable export to provide readable stacktraces set_property(TARGET ${TARGET} PROPERTY ENABLE_EXPORTS 1) +set_property(TARGET ${TARGET} PROPERTY DEFINE_SYMBOL "") if (NOT MSVC AND NOT WITH_STDLIB_DEBUG) if (NOT ReplXX_LIBRARY OR NOT ReplXX_INCLUDE_DIR) diff --git a/cpp_src/cmd/reindexer_tool/commandsexecutor.cc b/cpp_src/cmd/reindexer_tool/commandsexecutor.cc index 062aeec81..d4c37be2c 100644 --- a/cpp_src/cmd/reindexer_tool/commandsexecutor.cc +++ b/cpp_src/cmd/reindexer_tool/commandsexecutor.cc @@ -4,7 +4,9 @@ #include "cluster/config.h" #include "core/cjson/jsonbuilder.h" #include "core/reindexer.h" +#include "core/system_ns_names.h" #include "coroutine/waitgroup.h" +#include "estl/gift_str.h" #include "executorscommand.h" #include "tableviewscroller.h" #include "tools/catch_and_return.h" @@ -27,7 +29,7 @@ const std::string kVariableOutput = "output"; const std::string kOutputModeJson = "json"; const std::string kOutputModeTable = "table"; const std::string kOutputModePretty = "pretty"; -const std::string kVariableWithShardId = "with_shard_id"; +const std::string kVariableWithShardId = "with_shard_ids"; const std::string kBenchNamespace = "rxtool_bench"; const std::string kBenchIndex = "id"; const std::string kDumpModePrefix = "-- __dump_mode:"; @@ -132,7 +134,7 @@ bool CommandsExecutor::isHavingReplicationConfig(WrSerializer& wser Query q; typename DBInterface::QueryResultsT results(kResultsWithPayloadTypes | kResultsCJson | kResultsWithItemID); - auto err = db().Select(Query("#replicationstats").Where("type", CondEq, type), results); + auto err = db().Select(Query(reindexer::kReplicationStatsNamespace).Where("type", CondEq, type), results); if (!err.ok()) { throw err; } @@ -276,13 +278,17 @@ Error CommandsExecutor::runImpl(const std::string& dsn, Args&&... 
a for (auto node : value) { WrSerializer ser; reindexer::jsonValueToString(node.value, ser, 0, 0, false); - variables_[kVariableOutput] = std::string(ser.Slice()); + if (std::string_view(node.key) == kVariableOutput) { + variables_[kVariableOutput] = std::string(ser.Slice()); + } else if (std::string_view(node.key) == kVariableWithShardId) { + variables_[kVariableWithShardId] = std::string(ser.Slice()); + } } } catch (const gason::Exception& e) { err = Error(errParseJson, "Unable to parse output mode: %s", e.what()); } } - if (err.ok() && (variables_.empty() || variables_.find(kVariableOutput) == variables_.end())) { + if (err.ok() && variables_.find(kVariableOutput) == variables_.end()) { variables_[kVariableOutput] = kOutputModeJson; } if (err.ok() && !uri_.parse(dsn)) { @@ -645,7 +651,11 @@ std::vector ToJSONVector(const QueryResultsT& r) { template Error CommandsExecutor::commandSelect(const std::string& command) noexcept { try { - typename DBInterface::QueryResultsT results(kResultsWithPayloadTypes | kResultsCJson | kResultsWithItemID | kResultsWithRaw); + int flags = kResultsWithPayloadTypes | kResultsCJson | kResultsWithItemID | kResultsWithRaw; + if (variables_[kVariableWithShardId] == "on") { + flags |= kResultsNeedOutputShardId | kResultsWithShardId; + } + typename DBInterface::QueryResultsT results(flags); const auto q = Query::FromSQL(command); auto err = db().Select(q, results); @@ -779,7 +789,7 @@ Error CommandsExecutor::commandUpsert(const std::string& command) { } using namespace std::string_view_literals; - if (fromFile_ && std::string_view(nsName) == "#config"sv) { + if (fromFile_ && std::string_view(nsName) == reindexer::kConfigNamespace) { try { gason::JsonParser p; auto root = p.Parse(parser.CurPtr()); @@ -912,7 +922,7 @@ Error CommandsExecutor::commandDump(const std::string& command) { for (auto& nsDef : doNsDefs) { // skip system namespaces, except #config - if (reindexer::isSystemNamespaceNameFast(nsDef.name) && nsDef.name != "#config") { + if (reindexer::isSystemNamespaceNameFast(nsDef.name) && nsDef.name != reindexer::kConfigNamespace) { continue; } @@ -1352,7 +1362,7 @@ reindexer::Error CommandsExecutor::filterNamespacesByDumpMode(std:: } typename DBInterface::QueryResultsT qr; - auto err = db().Select(Query("#config").Where("type", CondEq, "sharding"), qr); + auto err = db().Select(Query(reindexer::kConfigNamespace).Where("type", CondEq, "sharding"), qr); if (!err.ok()) { return err; } diff --git a/cpp_src/cmd/reindexer_tool/commandsexecutor.h b/cpp_src/cmd/reindexer_tool/commandsexecutor.h index 876bc0c06..cb717ce89 100644 --- a/cpp_src/cmd/reindexer_tool/commandsexecutor.h +++ b/cpp_src/cmd/reindexer_tool/commandsexecutor.h @@ -195,7 +195,7 @@ class CommandsExecutor { - 'json' Unformatted JSON - 'pretty' Pretty printed JSON - 'table' Table view - -'with_shard_id' + -'with_shard_ids' possible values: - 'on' Add '#shard_id' field to items from sharded namespaces - 'off' diff --git a/cpp_src/cmd/reindexer_tool/dumpoptions.cc b/cpp_src/cmd/reindexer_tool/dumpoptions.cc index 002fe6b21..befdbda57 100644 --- a/cpp_src/cmd/reindexer_tool/dumpoptions.cc +++ b/cpp_src/cmd/reindexer_tool/dumpoptions.cc @@ -31,7 +31,7 @@ std::string_view DumpOptions::StrFromMode(Mode mode) { } } -Error DumpOptions::FromJSON(reindexer::span json) { +Error DumpOptions::FromJSON(std::span json) { try { gason::JsonParser parser; auto root = parser.Parse(json); diff --git a/cpp_src/cmd/reindexer_tool/dumpoptions.h b/cpp_src/cmd/reindexer_tool/dumpoptions.h index 699de6a88..2147cfc82 100644 
--- a/cpp_src/cmd/reindexer_tool/dumpoptions.h +++ b/cpp_src/cmd/reindexer_tool/dumpoptions.h @@ -1,6 +1,6 @@ #pragma once -#include "estl/span.h" +#include #include "tools/errors.h" namespace reindexer { @@ -16,7 +16,7 @@ struct DumpOptions { static Mode ModeFromStr(std::string_view mode); static std::string_view StrFromMode(Mode mode); - reindexer::Error FromJSON(reindexer::span json); + reindexer::Error FromJSON(std::span json); void GetJSON(reindexer::WrSerializer& ser) const; }; diff --git a/cpp_src/core/activity_context.h b/cpp_src/core/activity_context.h index e15d8a9a8..9ecd8c22f 100644 --- a/cpp_src/core/activity_context.h +++ b/cpp_src/core/activity_context.h @@ -8,9 +8,7 @@ #include #include #include "activity.h" -#include "activitylog.h" #include "estl/mutex.h" -#include "tools/clock.h" namespace reindexer { diff --git a/cpp_src/core/cbinding/reindexer_c.cc b/cpp_src/core/cbinding/reindexer_c.cc index b0170e936..e70fb09f2 100644 --- a/cpp_src/core/cbinding/reindexer_c.cc +++ b/cpp_src/core/cbinding/reindexer_c.cc @@ -1,13 +1,12 @@ #include "reindexer_c.h" #include -#include #include -#include #include "cgocancelcontextpool.h" #include "core/cjson/baseencoder.h" #include "debug/crashqueryreporter.h" +#include "estl/gift_str.h" #include "estl/syncpool.h" #include "events/subscriber_config.h" #include "reindexer_version.h" @@ -62,7 +61,7 @@ static reindexer_array_ret arr_ret2c(const Error& err_, reindexer_buffer* out, u return ret; } -static uint32_t span2arr(span d, reindexer_buffer* out, uint32_t out_size) { +static uint32_t span2arr(std::span d, reindexer_buffer* out, uint32_t out_size) { const auto sz = std::min(d.size(), size_t(out_size)); for (uint32_t i = 0; i < sz; ++i) { out[i].data = d[i].data(); @@ -130,11 +129,11 @@ static void results2c(std::unique_ptr result, struct reinde std::string_view rawBufOut; if (rawResProxying) { result->ser.SetOpts({.flags = flags, - .ptVersions = span(pt_versions, pt_versions_count), + .ptVersions = std::span(pt_versions, pt_versions_count), .fetchOffset = 0, .fetchLimit = INT_MAX, .withAggregations = true}); - result->ser.PutResultsRaw(result.get(), &rawBufOut); + result->ser.PutResultsRaw(*result, &rawBufOut); out->len = rawBufOut.size() ? rawBufOut.size() : result->ser.Len(); out->data = rawBufOut.size() ? 
							   uintptr_t(rawBufOut.data()) : uintptr_t(result->ser.Buf());
	} else {
@@ -143,11 +142,11 @@ static void results2c(std::unique_ptr result, struct reinde
			flags |= kResultsWithPayloadTypes;
		}
		result->ser.SetOpts({.flags = flags,
-							 .ptVersions = span(pt_versions, pt_versions_count),
+							 .ptVersions = std::span(pt_versions, pt_versions_count),
							 .fetchOffset = 0,
							 .fetchLimit = INT_MAX,
							 .withAggregations = true});
-		result->ser.PutResults(result.get(), bindingCaps.load(std::memory_order_relaxed), &result->proxiedRefsStorage);
+		result->ser.PutResults(*result, bindingCaps.load(std::memory_order_relaxed), &result->proxiedRefsStorage);
		out->len = result->ser.Len();
		out->data = uintptr_t(result->ser.Buf());
	}
@@ -161,17 +160,16 @@ static void results2c(std::unique_ptr result, struct reinde
 uintptr_t init_reindexer() {
	reindexer_init_locale();
	static std::atomic dbsCounter = {0};
-	ReindexerWrapper* db = new ReindexerWrapper(std::move(ReindexerConfig().WithDBName(fmt::sprintf("builtin_db_%d", dbsCounter++))));
+	auto db = new ReindexerWrapper(std::move(ReindexerConfig().WithDBName(fmt::sprintf("builtin_db_%d", dbsCounter++))));
	return reinterpret_cast(db);
 }

 uintptr_t init_reindexer_with_config(reindexer_config config) {
	reindexer_init_locale();
-	ReindexerWrapper* db =
-		new ReindexerWrapper(std::move(ReindexerConfig()
-										   .WithAllocatorCacheLimits(config.allocator_cache_limit, config.allocator_max_cache_part)
-										   .WithDBName(str2c(config.sub_db_name))
-										   .WithUpdatesSize(config.max_updates_size)));
+	auto db = new ReindexerWrapper(std::move(ReindexerConfig()
+												 .WithAllocatorCacheLimits(config.allocator_cache_limit, config.allocator_max_cache_part)
+												 .WithDBName(str2c(config.sub_db_name))
+												 .WithUpdatesSize(config.max_updates_size)));
	return reinterpret_cast(db);
 }
@@ -220,9 +218,9 @@ reindexer_error reindexer_modify_item_packed_tx(uintptr_t rx, uintptr_t tr, rein
	}
	Serializer ser(args.data, args.len);
-	int format = ser.GetVarUint();
-	int mode = ser.GetVarUint();
-	int state_token = ser.GetVarUint();
+	int format = ser.GetVarUInt();
+	int mode = ser.GetVarUInt();
+	int state_token = ser.GetVarUInt();
	Error err = err_not_init;
	auto item = trw->tr_.NewItem();
	proccess_packed_item(item, mode, state_token, data, format, err);
@@ -234,7 +232,7 @@ reindexer_error reindexer_modify_item_packed_tx(uintptr_t rx, uintptr_t tr, rein
		}
	}
	if (err.ok()) {
-		unsigned preceptsCount = ser.GetVarUint();
+		unsigned preceptsCount = ser.GetVarUInt();
		std::vector precepts;
		precepts.reserve(preceptsCount);
		while (preceptsCount--) {
@@ -252,9 +250,9 @@ reindexer_ret reindexer_modify_item_packed(uintptr_t rx, reindexer_buffer args,
	try {
		Serializer ser(args.data, args.len);
		std::string_view ns = ser.GetVString();
-		int format = ser.GetVarUint();
-		int mode = ser.GetVarUint();
-		int state_token = ser.GetVarUint();
+		int format = ser.GetVarUInt();
+		int mode = ser.GetVarUInt();
+		int state_token = ser.GetVarUInt();
		Error err = err_not_init;

		if (rx) {
@@ -266,7 +264,7 @@ reindexer_ret reindexer_modify_item_packed(uintptr_t rx, reindexer_buffer args,
		query_results_ptr res;
		if (err.ok()) {
-			unsigned preceptsCount = ser.GetVarUint();
+			unsigned preceptsCount = ser.GetVarUInt();
			const bool needSaveItemValueInQR = preceptsCount;
			std::vector precepts;
			precepts.reserve(preceptsCount);
@@ -447,14 +445,12 @@ reindexer_error reindexer_add_index(uintptr_t rx, reindexer_string nsName, reind
	if (rx) {
		CGORdxCtxKeeper rdxKeeper(rx, ctx_info, ctx_pool);
		std::string json(str2cv(indexDefJson));
-		IndexDef indexDef;
-
-		auto err = indexDef.FromJSON(giftStr(json));
-		if (!err.ok()) {
-			return error2c(err);
+		const auto indexDef = IndexDef::FromJSON(giftStr(json));
+		if (!indexDef) {
+			return error2c(indexDef.error());
		}
-		res = rdxKeeper.db().AddIndex(str2cv(nsName), indexDef);
+		res = rdxKeeper.db().AddIndex(str2cv(nsName), *indexDef);
	}
	return error2c(res);
 }
@@ -464,14 +460,13 @@ reindexer_error reindexer_update_index(uintptr_t rx, reindexer_string nsName, re
	if (rx) {
		CGORdxCtxKeeper rdxKeeper(rx, ctx_info, ctx_pool);
		std::string json(str2cv(indexDefJson));
-		IndexDef indexDef;
-		auto err = indexDef.FromJSON(giftStr(json));
-		if (!err.ok()) {
-			return error2c(err);
+		const auto indexDef = IndexDef::FromJSON(giftStr(json));
+		if (!indexDef) {
+			return error2c(indexDef.error());
		}
-		res = rdxKeeper.db().UpdateIndex(str2cv(nsName), indexDef);
+		res = rdxKeeper.db().UpdateIndex(str2cv(nsName), *indexDef);
	}
	return error2c(res);
 }
@@ -494,9 +489,8 @@ reindexer_error reindexer_set_schema(uintptr_t rx, reindexer_string nsName, rein
	return error2c(res);
 }

-// TODO: Rename this method, when all the connectors will support new version of connect
-reindexer_error reindexer_connect_v4(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers,
-									 BindingCapabilities caps) {
+reindexer_error reindexer_connect(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers,
+								  BindingCapabilities caps) {
	SemVersion cliVersion(str2cv(client_vers));
	if (opts.options & kConnectOptWarnVersion) {
		SemVersion libVersion(REINDEX_VERSION);
@@ -518,11 +512,6 @@ reindexer_error reindexer_connect_v4(uintptr_t rx, reindexer_string dsn, Connect
	return error2c(err);
 }

-// TODO: Remove this wrapper, when all the connectors will support new version of connect
-reindexer_error reindexer_connect(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers) {
-	return reindexer_connect_v4(rx, dsn, opts, client_vers, BindingCapabilities{0});
-}
-
 reindexer_error reindexer_init_system_namespaces(uintptr_t rx) {
	auto db = reinterpret_cast(rx);
	if (!db) {
@@ -576,7 +565,7 @@ reindexer_ret reindexer_select_query(uintptr_t rx, struct reindexer_buffer in, i
	Query q = Query::Deserialize(ser);
	while (!ser.Eof()) {
-		const auto joinType = JoinType(ser.GetVarUint());
+		const auto joinType = JoinType(ser.GetVarUInt());
		JoinedQuery q1{joinType, Query::Deserialize(ser)};
		if (q1.joinType == JoinType::Merge) {
			q.Merge(std::move(q1));
diff --git a/cpp_src/core/cbinding/reindexer_c.h b/cpp_src/core/cbinding/reindexer_c.h
index 6823c886c..a27e6bc89 100644
--- a/cpp_src/core/cbinding/reindexer_c.h
+++ b/cpp_src/core/cbinding/reindexer_c.h
@@ -13,9 +13,8 @@ uintptr_t init_reindexer_with_config(reindexer_config config);

 void destroy_reindexer(uintptr_t rx);

-reindexer_error reindexer_connect_v4(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers,
-									 BindingCapabilities caps);
-reindexer_error reindexer_connect(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers);
+reindexer_error reindexer_connect(uintptr_t rx, reindexer_string dsn, ConnectOpts opts, reindexer_string client_vers,
+								  BindingCapabilities caps);
 reindexer_error reindexer_ping(uintptr_t rx);
 reindexer_error reindexer_init_system_namespaces(uintptr_t rx);
diff --git a/cpp_src/core/cbinding/resultserializer.cc b/cpp_src/core/cbinding/resultserializer.cc
index 1b9254e6b..3c7f546ea 100644
--- a/cpp_src/core/cbinding/resultserializer.cc
+++ b/cpp_src/core/cbinding/resultserializer.cc
@@
-3,6 +3,7 @@ #include "core/cjson/tagsmatcher.h" #include "core/queryresults/joinresults.h" #include "core/queryresults/queryresults.h" +#include "core/type_consts.h" #include "tools/logger.h" #include "wal/walrecord.h" @@ -21,19 +22,19 @@ constexpr int kKnownResultsFlagsMask = int(GetKnownFlagsBitMask(kResultsFlagMaxV void WrResultSerializer::resetUnknownFlags() noexcept { opts_.flags &= kKnownResultsFlagsMask; } -void WrResultSerializer::putQueryParams(const BindingCapabilities& caps, QueryResults* results) { +void WrResultSerializer::putQueryParams(const BindingCapabilities& caps, QueryResults& results) { // Flags of present objects PutVarUint(opts_.flags); // Total - PutVarUint(results->TotalCount()); + PutVarUint(results.TotalCount()); // Count of returned items by query - PutVarUint(results->Count()); + PutVarUint(results.Count()); // Count of serialized items PutVarUint(opts_.fetchLimit); if (opts_.flags & kResultsWithPayloadTypes) { assertrx(opts_.ptVersions.data()); - const auto mergedNsCount = results->GetMergedNSCount(); + const auto mergedNsCount = results.GetMergedNSCount(); if (int(opts_.ptVersions.size()) != mergedNsCount) { logPrintf(LogWarning, "ptVersionsCount != results->GetMergedNSCount: %d != %d. Client's metadata can become inconsistent.", opts_.ptVersions.size(), mergedNsCount); @@ -45,9 +46,9 @@ void WrResultSerializer::putQueryParams(const BindingCapabilities& caps, QueryRe putExtraParams(caps, results); } -void WrResultSerializer::putExtraParams(const BindingCapabilities& caps, QueryResults* results) { +void WrResultSerializer::putExtraParams(const BindingCapabilities& caps, QueryResults& results) { if (opts_.withAggregations) { - for (const AggregationResult& aggregationRes : results->GetAggregationResults()) { + for (const AggregationResult& aggregationRes : results.GetAggregationResults()) { PutVarUint(QueryResultAggregation); auto slicePosSaver = StartSlice(); if ((opts_.flags & kResultsFormatMask) == kResultsMsgPack) { @@ -57,28 +58,28 @@ void WrResultSerializer::putExtraParams(const BindingCapabilities& caps, QueryRe } } - if (!results->GetExplainResults().empty()) { + if (!results.GetExplainResults().empty()) { PutVarUint(QueryResultExplain); - PutSlice(results->GetExplainResults()); + PutSlice(results.GetExplainResults()); } } if (opts_.flags & kResultsWithShardId) { - if (!results->IsDistributed() && results->Count() > 0) { + if (!results.IsDistributed() && results.Count() > 0) { PutVarUint(QueryResultShardId); - PutVarUint(results->GetCommonShardID()); + PutVarUint(results.GetCommonShardID()); opts_.flags &= ~kResultsWithShardId; // not set shardId for item } } if (caps.HasIncarnationTags()) { - int64_t shardingConfVer = results->GetShardingConfigVersion(); + int64_t shardingConfVer = results.GetShardingConfigVersion(); if (shardingConfVer != -1) { PutVarUint(QueryResultShardingVersion); PutVarUint(shardingConfVer); } PutVarUint(QueryResultIncarnationTags); - auto tags = results->GetIncarnationTags(); + auto tags = results.GetIncarnationTags(); PutVarUint(tags.size()); for (auto& shardTags : tags) { PutVarint(shardTags.shardId); @@ -89,6 +90,11 @@ void WrResultSerializer::putExtraParams(const BindingCapabilities& caps, QueryRe } } + if (caps.HasComplexRank() && results.HaveRank() && (opts_.flags & kResultsWithRank)) { + PutVarUint(QueryResultRankFormat); + PutVarUint(RankFormat::SingleFloatValue); + } + PutVarUint(QueryResultEnd); } @@ -97,7 +103,8 @@ static ItemRef GetItemRefWithStore(const LocalQueryResults::Iterator& it, QueryR static ItemRef 
GetItemRefWithStore(QueryResults::Iterator& it, QueryResults::ProxiedRefsStorage* storage) { return it.GetItemRef(storage); } template -void WrResultSerializer::putItemParams(ItT& it, int shardId, QueryResults::ProxiedRefsStorage* storage, const QueryResults* result) { +void WrResultSerializer::putItemParams(ItT& it, int shardId, QueryResults::ProxiedRefsStorage* storage, const QueryResults* result, + const BindingCapabilities& caps) { const auto itemRef = GetItemRefWithStore(it, storage); if (opts_.flags & kResultsWithItemID) { @@ -110,7 +117,12 @@ void WrResultSerializer::putItemParams(ItT& it, int shardId, QueryResults::Proxi } if (opts_.flags & kResultsWithRank) { - PutVarUint(itemRef.Proc()); + const RankT rank = it.IsRanked() ? it.GetItemRefRanked().Rank() : 0.0; + if (caps.HasComplexRank()) { + PutRank(rank); + } else { + PutVarUint(uint16_t(rank)); + } } if (opts_.flags & kResultsWithRaw) { @@ -174,15 +186,15 @@ void WrResultSerializer::putItemParams(ItT& it, int shardId, QueryResults::Proxi } } -void WrResultSerializer::putPayloadTypes(WrSerializer& ser, const QueryResults* results, const ResultFetchOpts& opts, int cnt, +void WrResultSerializer::putPayloadTypes(WrSerializer& ser, const QueryResults& results, const ResultFetchOpts& opts, int cnt, int totalCnt) { ser.PutVarUint(cnt); for (int nsid = 0; nsid < totalCnt; ++nsid) { - const TagsMatcher& tm = results->GetTagsMatcher(nsid); + const TagsMatcher& tm = results.GetTagsMatcher(nsid); if (int32_t(tm.version() ^ tm.stateToken()) != opts.ptVersions[nsid]) { ser.PutVarUint(nsid); - ser.PutVString(results->GetPayloadType(nsid)->Name()); - const PayloadType& t = results->GetPayloadType(nsid); + ser.PutVString(results.GetPayloadType(nsid)->Name()); + const PayloadType& t = results.GetPayloadType(nsid); // Serialize tags matcher ser.PutVarUint(tm.stateToken()); ser.PutVarUint(tm.version()); @@ -193,10 +205,10 @@ void WrResultSerializer::putPayloadTypes(WrSerializer& ser, const QueryResults* } } -std::pair WrResultSerializer::getPtUpdatesCount(const QueryResults* results) { +std::pair WrResultSerializer::getPtUpdatesCount(const QueryResults& results) { if (opts_.flags & kResultsWithPayloadTypes) { assertrx(opts_.ptVersions.data()); - const auto mergedNsCount = results->GetMergedNSCount(); + const auto mergedNsCount = results.GetMergedNSCount(); if (int(opts_.ptVersions.size()) != mergedNsCount) { logPrintf(LogWarning, "ptVersionsCount != results->GetMergedNSCount: %d != %d. 
Client's meta data can become incosistent.", opts_.ptVersions.size(), mergedNsCount); @@ -204,7 +216,7 @@ std::pair WrResultSerializer::getPtUpdatesCount(const QueryResults* re int cnt = 0, totalCnt = std::min(mergedNsCount, int(opts_.ptVersions.size())); for (int i = 0; i < totalCnt; i++) { - const TagsMatcher& tm = results->GetTagsMatcher(i); + const TagsMatcher& tm = results.GetTagsMatcher(i); if (int32_t(tm.version() ^ tm.stateToken()) != opts_.ptVersions[i]) { ++cnt; } @@ -214,37 +226,37 @@ std::pair WrResultSerializer::getPtUpdatesCount(const QueryResults* re return std::make_pair(0, 0); } -bool WrResultSerializer::PutResults(QueryResults* result, const BindingCapabilities& caps, QueryResults::ProxiedRefsStorage* storage) { - if (result->IsWALQuery() && !(opts_.flags & kResultsWithRaw) && (opts_.flags & kResultsFormatMask) != kResultsJson) { +bool WrResultSerializer::PutResults(QueryResults& result, const BindingCapabilities& caps, QueryResults::ProxiedRefsStorage* storage) { + if (result.IsWALQuery() && !(opts_.flags & kResultsWithRaw) && (opts_.flags & kResultsFormatMask) != kResultsJson) { throw Error(errParams, "Query results contain WAL items. Query results from WAL must either be requested in JSON format or with client, " "supporting RAW items"); } - if (opts_.fetchOffset > result->Count()) { - opts_.fetchOffset = result->Count(); + if (opts_.fetchOffset > result.Count()) { + opts_.fetchOffset = result.Count(); } - if (opts_.fetchOffset + opts_.fetchLimit > result->Count()) { - opts_.fetchLimit = result->Count() - opts_.fetchOffset; + if (opts_.fetchOffset + opts_.fetchLimit > result.Count()) { + opts_.fetchLimit = result.Count() - opts_.fetchOffset; } // Result has items from multiple namespaces, so pass nsid to each item - if (result->GetMergedNSCount() > 1) { + if (result.GetMergedNSCount() > 1) { opts_.flags |= kResultsWithNsID; } // Result has joined items, so pass them to client within items from main NS - if (result->HaveJoined()) { + if (result.HaveJoined()) { opts_.flags |= kResultsWithJoined; } - if (result->HaveRank()) { + if (result.HaveRank()) { opts_.flags |= kResultsWithRank; } - if (result->NeedOutputRank()) { + if (result.NeedOutputRank()) { opts_.flags |= kResultsNeedOutputRank; } // If data is not cacheable, just do not pass item's ID and LSN. Clients should not cache this data - if (!result->IsCacheEnabled()) { + if (!result.IsCacheEnabled()) { opts_.flags &= ~kResultsWithItemID; } // MsgPack items contain fields names so there is no need to transfer payload types @@ -254,26 +266,30 @@ bool WrResultSerializer::PutResults(QueryResults* result, const BindingCapabilit } // client with version 'compareVersionShardId' not support shardId - if (result->HaveShardIDs() && (opts_.flags & kResultsWithItemID) && !(opts_.flags & kResultsWithShardId)) { + const bool resultsHaveShardIDs = result.HaveShardIDs(); + if (resultsHaveShardIDs && (opts_.flags & kResultsWithItemID) && !(opts_.flags & kResultsWithShardId)) { if (caps.HasResultsWithShardIDs()) { opts_.flags |= kResultsWithShardId; } else { opts_.flags &= ~kResultsWithItemID; } } + if (!resultsHaveShardIDs) { + opts_.flags &= ~kResultsWithShardId; + } putQueryParams(caps, result); size_t saveLen = len_; const bool storeAsPointers = (opts_.flags & kResultsFormatMask) == kResultsPtrs; auto ptrStorage = storeAsPointers ? storage : nullptr; - if (ptrStorage && result->HasProxiedResults()) { - storage->reserve(result->HaveJoined() ? 
2 * opts_.fetchLimit : opts_.fetchLimit); + if (ptrStorage && result.HasProxiedResults()) { + storage->reserve(result.HaveJoined() ? 2 * opts_.fetchLimit : opts_.fetchLimit); } - auto rowIt = result->begin() + opts_.fetchOffset; + auto rowIt = result.begin() + opts_.fetchOffset; for (unsigned i = 0, limit = opts_.fetchLimit; i < limit; ++i, ++rowIt) { // Put Item ID and version - putItemParams(rowIt, rowIt.GetShardId(), storage, result); + putItemParams(rowIt, rowIt.GetShardId(), storage, &result, caps); if (opts_.flags & kResultsWithJoined) { auto jIt = rowIt.GetJoined(storage); PutVarUint(jIt.getJoinedItemsCount() > 0 ? jIt.getJoinedFieldsCount() : 0); @@ -285,9 +301,9 @@ bool WrResultSerializer::PutResults(QueryResults* result, const BindingCapabilit continue; } LocalQueryResults qr = it.ToQueryResults(); - qr.addNSContext(*result, joinedField, lsn_t()); + qr.addNSContext(result, joinedField, lsn_t()); for (auto& jit : qr) { - putItemParams(jit, rowIt.GetShardId(), storage, nullptr); + putItemParams(jit, rowIt.GetShardId(), storage, nullptr, caps); } } } @@ -296,22 +312,22 @@ bool WrResultSerializer::PutResults(QueryResults* result, const BindingCapabilit grow((opts_.fetchLimit - 1) * (len_ - saveLen)); } } - return opts_.fetchOffset + opts_.fetchLimit >= result->Count(); + return opts_.fetchOffset + opts_.fetchLimit >= result.Count(); } -bool WrResultSerializer::PutResultsRaw(QueryResults* result, std::string_view* rawBufOut) { - if (opts_.fetchOffset > result->Count()) { - opts_.fetchOffset = result->Count(); +bool WrResultSerializer::PutResultsRaw(QueryResults& result, std::string_view* rawBufOut) { + if (opts_.fetchOffset > result.Count()) { + opts_.fetchOffset = result.Count(); } - if (opts_.fetchOffset + opts_.fetchLimit > result->Count()) { - opts_.fetchLimit = result->Count() - opts_.fetchOffset; + if (opts_.fetchOffset + opts_.fetchLimit > result.Count()) { + opts_.fetchLimit = result.Count() - opts_.fetchOffset; } - result->FetchRawBuffer(opts_.flags, opts_.fetchOffset, opts_.fetchLimit); + result.FetchRawBuffer(opts_.flags, opts_.fetchOffset, opts_.fetchLimit); client::ParsedQrRawBuffer raw; - const bool holdsRemoteData = result->GetRawProxiedBuffer(raw); + const bool holdsRemoteData = result.GetRawProxiedBuffer(raw); auto cntP = getPtUpdatesCount(result); auto& buf = *raw.buf; if (cntP.first) { diff --git a/cpp_src/core/cbinding/resultserializer.h b/cpp_src/core/cbinding/resultserializer.h index 1283c3f48..13c85d968 100644 --- a/cpp_src/core/cbinding/resultserializer.h +++ b/cpp_src/core/cbinding/resultserializer.h @@ -1,8 +1,6 @@ #pragma once -#include #include "core/queryresults/queryresults.h" -#include "estl/span.h" -#include "tools/semversion.h" +#include #include "tools/serializer.h" namespace reindexer { @@ -11,7 +9,7 @@ class QueryResults; struct ResultFetchOpts { int flags; - span ptVersions; + std::span ptVersions; unsigned fetchOffset; unsigned fetchLimit; bool withAggregations; @@ -32,8 +30,8 @@ class WrResultSerializer : public WrSerializer { resetUnknownFlags(); } - bool PutResults(QueryResults* results, const BindingCapabilities& caps, QueryResults::ProxiedRefsStorage* storage = nullptr); - bool PutResultsRaw(QueryResults* results, std::string_view* rawBufOut = nullptr); + bool PutResults(QueryResults& results, const BindingCapabilities& caps, QueryResults::ProxiedRefsStorage* storage = nullptr); + bool PutResultsRaw(QueryResults& results, std::string_view* rawBufOut = nullptr); void SetOpts(const ResultFetchOpts& opts) noexcept { opts_ = opts; } static bool 
IsRawResultsSupported(const BindingCapabilities& caps, const QueryResults& results) noexcept { return !results.HaveShardIDs() || caps.HasResultsWithShardIDs(); @@ -41,12 +39,13 @@ class WrResultSerializer : public WrSerializer { private: void resetUnknownFlags() noexcept; - void putQueryParams(const BindingCapabilities& caps, QueryResults* query); + void putQueryParams(const BindingCapabilities& caps, QueryResults& query); template - void putItemParams(ItT& it, int shardId, QueryResults::ProxiedRefsStorage* storage, const QueryResults* result); - void putExtraParams(const BindingCapabilities& caps, QueryResults* query); - static void putPayloadTypes(WrSerializer& ser, const QueryResults* results, const ResultFetchOpts& opts, int cnt, int totalCnt); - std::pair getPtUpdatesCount(const QueryResults* results); + void putItemParams(ItT& it, int shardId, QueryResults::ProxiedRefsStorage* storage, const QueryResults* result, + const BindingCapabilities&); + void putExtraParams(const BindingCapabilities& caps, QueryResults& query); + static void putPayloadTypes(WrSerializer& ser, const QueryResults& results, const ResultFetchOpts& opts, int cnt, int totalCnt); + std::pair getPtUpdatesCount(const QueryResults& results); ResultFetchOpts opts_; }; diff --git a/cpp_src/core/cbinding/updatesobserver.h b/cpp_src/core/cbinding/updatesobserver.h index 702a5bf32..1f39153fc 100644 --- a/cpp_src/core/cbinding/updatesobserver.h +++ b/cpp_src/core/cbinding/updatesobserver.h @@ -17,7 +17,7 @@ class BufferedUpdateObserver : public IEventsObserver { assertrx_dbg(queue_.capacity() - queue_.size() >= 1); queue_.write(user.Serialize(streamsMask, opts, rec)); } - [[nodiscard]] span TryReadUpdates() noexcept { return queue_.tail(); } + [[nodiscard]] std::span TryReadUpdates() noexcept { return queue_.tail(); } void EraseUpdates(size_t count) noexcept { queue_.erase_chunks(count); } private: diff --git a/cpp_src/core/cjson/baseencoder.cc b/cpp_src/core/cjson/baseencoder.cc index aabd273b0..b32720be5 100644 --- a/cpp_src/core/cjson/baseencoder.cc +++ b/cpp_src/core/cjson/baseencoder.cc @@ -164,15 +164,29 @@ bool BaseEncoder::encode(ConstPayload* pl, Serializer& rdser, Builder& int& cnt = fieldsoutcnt_[tagField]; switch (tagType) { case TAG_ARRAY: { - const auto count = rdser.GetVarUint(); + const auto count = rdser.GetVarUInt(); if (visible) { - pl->Type().Field(tagField).Type().EvaluateOneOf( + f.Type().EvaluateOneOf( [&](KeyValueType::Bool) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, [&](KeyValueType::Int) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, [&](KeyValueType::Int64) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, [&](KeyValueType::Double) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, [&](KeyValueType::String) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, [&](KeyValueType::Uuid) { builder.Array(tagName, pl->GetArray(tagField).subspan(cnt, count), cnt); }, + [&](KeyValueType::FloatVector) { + objectScalarIndexes_.set(tagField); // Currently float vector is always single-value scalar + auto view = ConstFloatVectorView(pl->Get(tagField, 0)); + if rx_unlikely (view.IsStripped()) { + throw Error(errLogic, "Attempt to serialize stripped vector"); + } + assertrx_dbg(unsigned(view.Dimension()) == count); + builder.Array(tagName, view.Span(), 0); // Offset is always zero + }, + [&](KeyValueType::Float) { + // Indexed field can not contain float array now + 
assertrx(false); + abort(); + }, [](OneOf) noexcept { assertrx(false); abort(); @@ -194,6 +208,7 @@ bool BaseEncoder::encode(ConstPayload* pl, Serializer& rdser, Builder& case TAG_END: case TAG_OBJECT: case TAG_UUID: + case TAG_FLOAT: objectScalarIndexes_.set(tagField); if (visible) { builder.Put(tagName, pl->Get(tagField, cnt), cnt); @@ -210,9 +225,8 @@ bool BaseEncoder::encode(ConstPayload* pl, Serializer& rdser, Builder& if (atagType == TAG_OBJECT) { if (visible) { auto arrNode = builder.Array(tagName); - auto& lastIdxTag = indexedTagsPath_.back(); for (size_t i = 0; i < atagCount; ++i) { - lastIdxTag.SetIndex(i); + indexedTagsPath_.back().SetIndex(i); encode(pl, rdser, arrNode, true); } } else { @@ -247,7 +261,8 @@ bool BaseEncoder::encode(ConstPayload* pl, Serializer& rdser, Builder& case TAG_STRING: case TAG_NULL: case TAG_END: - case TAG_UUID: { + case TAG_UUID: + case TAG_FLOAT: { const KeyValueType kvt{tagType}; if (visible) { Variant value = rdser.GetRawVariant(kvt); @@ -284,7 +299,7 @@ bool BaseEncoder::collectTagsSizes(ConstPayload& pl, Serializer& rdser) assertrx(tagField < pl.NumFields()); switch (tagType) { case TAG_ARRAY: { - const auto count = rdser.GetVarUint(); + const auto count = rdser.GetVarUInt(); tagsLengths_.back() = count; break; } @@ -296,6 +311,7 @@ bool BaseEncoder::collectTagsSizes(ConstPayload& pl, Serializer& rdser) case TAG_OBJECT: case TAG_END: case TAG_UUID: + case TAG_FLOAT: break; } } else { @@ -330,9 +346,9 @@ bool BaseEncoder::collectTagsSizes(ConstPayload& pl, Serializer& rdser) case TAG_STRING: case TAG_NULL: case TAG_END: - case TAG_UUID: { + case TAG_UUID: + case TAG_FLOAT: rdser.SkipRawVariant(KeyValueType{tagType}); - } } } if (tagName && filter_) { diff --git a/cpp_src/core/cjson/cjsonbuilder.cc b/cpp_src/core/cjson/cjsonbuilder.cc index 9273e5b11..fc9c109ff 100644 --- a/cpp_src/core/cjson/cjsonbuilder.cc +++ b/cpp_src/core/cjson/cjsonbuilder.cc @@ -31,7 +31,7 @@ CJsonBuilder CJsonBuilder::Array(int tagName, ObjType type) { return CJsonBuilder(*ser_, type, tm_, tagName); } -void CJsonBuilder::Array(int tagName, span data, int /*offset*/) { +void CJsonBuilder::Array(int tagName, std::span data, int /*offset*/) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_UUID)); for (auto d : data) { @@ -83,6 +83,17 @@ CJsonBuilder& CJsonBuilder::Put(int tagName, double arg, int /*offset*/) { return *this; } +CJsonBuilder& CJsonBuilder::Put(int tagName, float arg, int /*offset*/) { + if (type_ == ObjType::TypeArray) { + itemType_ = TAG_FLOAT; + } else { + putTag(tagName, TAG_FLOAT); + } + ser_->PutFloat(arg); + ++count_; + return *this; +} + CJsonBuilder& CJsonBuilder::Put(int tagName, std::string_view arg, int /*offset*/) { if (type_ == ObjType::TypeArray) { itemType_ = TAG_STRING; @@ -110,14 +121,15 @@ CJsonBuilder& CJsonBuilder::Null(int tagName) { return *this; } -CJsonBuilder& CJsonBuilder::Ref(int tagName, const Variant& v, int field) { - v.Type().EvaluateOneOf([&](OneOf) { ser_->PutCTag(ctag{TAG_VARINT, tagName, field}); }, - [&](KeyValueType::Bool) { ser_->PutCTag(ctag{TAG_BOOL, tagName, field}); }, - [&](KeyValueType::Double) { ser_->PutCTag(ctag{TAG_DOUBLE, tagName, field}); }, - [&](KeyValueType::String) { ser_->PutCTag(ctag{TAG_STRING, tagName, field}); }, - [&](KeyValueType::Uuid) { ser_->PutCTag(ctag{TAG_UUID, tagName, field}); }, - [&](OneOf) { ser_->PutCTag(ctag{TAG_NULL, tagName}); }, - [](OneOf) noexcept { std::abort(); }); +CJsonBuilder& CJsonBuilder::Ref(int tagName, const KeyValueType& type, int 
field) { + type.EvaluateOneOf([&](OneOf) { ser_->PutCTag(ctag{TAG_VARINT, tagName, field}); }, + [&](KeyValueType::Bool) { ser_->PutCTag(ctag{TAG_BOOL, tagName, field}); }, + [&](KeyValueType::Double) { ser_->PutCTag(ctag{TAG_DOUBLE, tagName, field}); }, + [&](KeyValueType::String) { ser_->PutCTag(ctag{TAG_STRING, tagName, field}); }, + [&](KeyValueType::Uuid) { ser_->PutCTag(ctag{TAG_UUID, tagName, field}); }, + [&](KeyValueType::Float) { ser_->PutCTag(ctag{TAG_FLOAT, tagName, field}); }, + [&](OneOf) { ser_->PutCTag(ctag{TAG_NULL, tagName}); }, + [](OneOf) noexcept { std::abort(); }); return *this; } @@ -128,19 +140,19 @@ CJsonBuilder& CJsonBuilder::ArrayRef(int tagName, int field, int count) { } CJsonBuilder& CJsonBuilder::Put(int tagName, const Variant& kv, int offset) { - kv.Type().EvaluateOneOf([&](KeyValueType::Int) { Put(tagName, int(kv), offset); }, - [&](KeyValueType::Int64) { Put(tagName, int64_t(kv), offset); }, - [&](KeyValueType::Double) { Put(tagName, double(kv), offset); }, - [&](KeyValueType::String) { Put(tagName, std::string_view(kv), offset); }, - [&](KeyValueType::Null) { Null(tagName); }, [&](KeyValueType::Bool) { Put(tagName, bool(kv), offset); }, - [&](KeyValueType::Tuple) { - auto arrNode = Array(tagName); - for (auto& val : kv.getCompositeValues()) { - arrNode.Put(nullptr, val); - } - }, - [&](KeyValueType::Uuid) { Put(tagName, Uuid{kv}, offset); }, - [](OneOf) noexcept {}); + kv.Type().EvaluateOneOf( + [&](KeyValueType::Int) { Put(tagName, int(kv), offset); }, [&](KeyValueType::Int64) { Put(tagName, int64_t(kv), offset); }, + [&](KeyValueType::Double) { Put(tagName, double(kv), offset); }, [&](KeyValueType::Float) { Put(tagName, float(kv), offset); }, + [&](KeyValueType::String) { Put(tagName, std::string_view(kv), offset); }, [&](KeyValueType::Null) { Null(tagName); }, + [&](KeyValueType::Bool) { Put(tagName, bool(kv), offset); }, + [&](KeyValueType::Tuple) { + auto arrNode = Array(tagName); + for (auto& val : kv.getCompositeValues()) { + arrNode.Put(nullptr, val); + } + }, + [&](KeyValueType::Uuid) { Put(tagName, Uuid{kv}, offset); }, + [](OneOf) noexcept { assertrx_throw(false); }); return *this; } diff --git a/cpp_src/core/cjson/cjsonbuilder.h b/cpp_src/core/cjson/cjsonbuilder.h index 4ecd4a7bf..10b543d5f 100644 --- a/cpp_src/core/cjson/cjsonbuilder.h +++ b/cpp_src/core/cjson/cjsonbuilder.h @@ -1,7 +1,7 @@ #pragma once #include "core/keyvalue/p_string.h" -#include "estl/span.h" +#include #include "objtype.h" #include "tagsmatcher.h" @@ -34,42 +34,49 @@ class CJsonBuilder { } CJsonBuilder Object(std::nullptr_t) { return Object(0); } - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_STRING)); for (auto d : data) { ser_->PutVString(d); } } - void Array(int tagName, span data, int offset = 0); - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int offset = 0); + void Array(int tagName, std::span data, int /*offset*/ = 0) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_VARINT)); for (auto d : data) { ser_->PutVarint(d); } } - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_VARINT)); for (auto d : data) { ser_->PutVarint(d); } } - void Array(int tagName, span 
data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_BOOL)); for (auto d : data) { ser_->PutBool(d); } } - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(data.size(), TAG_DOUBLE)); for (auto d : data) { ser_->PutDouble(d); } } + void Array(int tagName, std::span data, int /*offset*/ = 0) { + ser_->PutCTag(ctag{TAG_ARRAY, tagName}); + ser_->PutCArrayTag(carraytag(data.size(), TAG_FLOAT)); + for (auto d : data) { + ser_->PutFloat(d); + } + } void Array(int tagName, Serializer& ser, TagType tagType, int count) { ser_->PutCTag(ctag{TAG_ARRAY, tagName}); ser_->PutCArrayTag(carraytag(count, tagType)); @@ -91,9 +98,10 @@ class CJsonBuilder { CJsonBuilder& Put(int tagName, int arg, int offset = 0); CJsonBuilder& Put(int tagName, int64_t arg, int offset = 0); CJsonBuilder& Put(int tagName, double arg, int offset = 0); + CJsonBuilder& Put(int tagName, float arg, int offset = 0); CJsonBuilder& Put(int tagName, std::string_view arg, int offset = 0); CJsonBuilder& Put(int tagName, Uuid arg, int offset = 0); - CJsonBuilder& Ref(int tagName, const Variant& v, int field); + CJsonBuilder& Ref(int tagName, const KeyValueType& type, int field); CJsonBuilder& ArrayRef(int tagName, int field, int count); CJsonBuilder& Null(int tagName); CJsonBuilder& Put(int tagName, const Variant& kv, int offset = 0); diff --git a/cpp_src/core/cjson/cjsondecoder.cc b/cpp_src/core/cjson/cjsondecoder.cc index f7c532e8b..ede5dbac2 100644 --- a/cpp_src/core/cjson/cjsondecoder.cc +++ b/cpp_src/core/cjson/cjsondecoder.cc @@ -7,7 +7,9 @@ namespace reindexer { template -bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrser, FilterT filter, RecoderT recoder, TagOptT) { +bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrser, FilterT filter, RecoderT recoder, TagOptT, + FloatVectorsHolderVector& floatVectorsHolder) { + using namespace std::string_view_literals; const ctag tag = rdser.GetCTag(); TagType tagType = tag.Type(); if (tag == kCTagEnd) { @@ -42,24 +44,54 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs if (tagType == TAG_ARRAY) { const carraytag atag = rdser.GetCArrayTag(); const auto count = atag.Count(); - if rx_unlikely (!fieldRef.IsArray()) { - throwUnexpectedArrayError(fieldRef); - } - validateArrayFieldRestrictions(fieldRef, count, "cjson"); - const int ofs = pl.ResizeArray(field, count, true); const TagType atagType = atag.Type(); - if (atagType != TAG_OBJECT) { - for (size_t i = 0; i < count; ++i) { - pl.Set(field, ofs + i, cjsonValueToVariant(atagType, rdser, fieldType)); + if (fieldRef.IsFloatVector()) { + ConstFloatVectorView vectView; + if (count != 0) { + if (atagType != TAG_DOUBLE && atagType != TAG_FLOAT && atagType != TAG_VARINT) { + throwUnexpectedArrayTypeForFloatVectorError("cjson"sv, fieldRef); + } + if (count != size_t(fieldRef.FloatVectorDimension())) { + throwUnexpectedArraySizeForFloatVectorError("cjson"sv, fieldRef, count); + } + auto vect = FloatVector::CreateNotInitialized(fieldRef.FloatVectorDimension()); + if (atagType == TAG_DOUBLE) { + for (size_t i = 0; i < count; ++i) { + vect.RawData()[i] = rdser.GetDouble(); + } + } else if (atagType == TAG_FLOAT) { + for (size_t i = 0; i < count; ++i) { + vect.RawData()[i] = rdser.GetFloat(); + } + } 
else if (atagType == TAG_VARINT) { + for (size_t i = 0; i < count; ++i) { + vect.RawData()[i] = rdser.GetVarint(); + } + } + floatVectorsHolder.Add(std::move(vect)); + vectView = floatVectorsHolder.Back(); } - } else { - for (size_t i = 0; i < count; ++i) { - pl.Set(field, ofs + i, cjsonValueToVariant(rdser.GetCTag().Type(), rdser, fieldType)); + objectScalarIndexes_.set(field); // Indexed float vector is treated as scalar value + pl.Set(field, Variant{vectView}); + wrser.PutCTag(ctag{TAG_ARRAY, tagName, field}); + wrser.PutVarUint(count); + } else if rx_likely (fieldRef.IsArray()) { + validateArrayFieldRestrictions(fieldRef, count, "cjson"); + const int ofs = pl.ResizeArray(field, count, true); + if (atagType != TAG_OBJECT) { + for (size_t i = 0; i < count; ++i) { + pl.Set(field, ofs + i, cjsonValueToVariant(atagType, rdser, fieldType)); + } + } else { + for (size_t i = 0; i < count; ++i) { + pl.Set(field, ofs + i, cjsonValueToVariant(rdser.GetCTag().Type(), rdser, fieldType)); + } } + wrser.PutCTag(ctag{TAG_ARRAY, tagName, field}); + wrser.PutVarUint(count); + } else { + throwUnexpectedArrayError("cjson"sv, fieldRef); } - - wrser.PutCTag(ctag{TAG_ARRAY, tagName, field}); - wrser.PutVarUint(count); } else { validateNonArrayFieldRestrictions(objectScalarIndexes_, pl, fieldRef, field, isInArray(), "cjson"); validateArrayFieldRestrictions(fieldRef, 1, "cjson"); @@ -67,10 +99,11 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs pl.Set(field, cjsonValueToVariant(tagType, rdser, fieldType), true); fieldType.EvaluateOneOf( [&](OneOf) { wrser.PutCTag(ctag{TAG_VARINT, tagName, field}); }, - [&](OneOf) { - wrser.PutCTag(ctag{fieldType.ToTagType(), tagName, field}); - }, - [&](OneOf) { assertrx(false); }); + [&](OneOf) { wrser.PutCTag(ctag{fieldType.ToTagType(), tagName, field}); }, + [&](OneOf) { + assertrx(false); + }); } } } else { @@ -83,7 +116,7 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs tagType = recoder.RegisterTagType(tagType, tagsPath_); wrser.PutCTag(ctag{tagType, tagName, field}); if (tagType == TAG_OBJECT) { - while (decodeCJson(pl, rdser, wrser, filter.MakeCleanCopy(), recoder.MakeCleanCopy(), NamedTagOpt{})); + while (decodeCJson(pl, rdser, wrser, filter.MakeCleanCopy(), recoder.MakeCleanCopy(), NamedTagOpt{}, floatVectorsHolder)); } else if (recoder.Recode(rdser, wrser)) { // No more actions needed after recoding } else if (tagType == TAG_ARRAY) { @@ -94,7 +127,8 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs CounterGuardIR32 g(arrayLevel_); if (atagType == TAG_OBJECT) { for (size_t i = 0; i < count; ++i) { - decodeCJson(pl, rdser, wrser, filter.MakeCleanCopy(), recoder.MakeCleanCopy(), NamelessTagOpt{}); + decodeCJson(pl, rdser, wrser, filter.MakeCleanCopy(), recoder.MakeCleanCopy(), NamelessTagOpt{}, + floatVectorsHolder); } } else { for (size_t i = 0; i < count; ++i) { @@ -110,7 +144,7 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs } else { // !match wrser.PutCTag(ctag{tagType, tagName, field}); - while (decodeCJson(pl, rdser, wrser, filter.MakeSkipFilter(), recoder.MakeCleanCopy(), NamedTagOpt{})); + while (decodeCJson(pl, rdser, wrser, filter.MakeSkipFilter(), recoder.MakeCleanCopy(), NamedTagOpt{}, floatVectorsHolder)); } } @@ -124,7 +158,7 @@ bool CJsonDecoder::decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrs [[nodiscard]] Variant CJsonDecoder::cjsonValueToVariant(TagType tagType, Serializer& rdser, KeyValueType 
fieldType) { if (fieldType.Is() && tagType != TagType::TAG_STRING) { auto& back = storage_.emplace_back(rdser.GetRawVariant(KeyValueType{tagType}).As()); - return Variant(p_string(back), Variant::no_hold_t{}); + return Variant(p_string(back), Variant::noHold); } else { return reindexer::cjsonValueToVariant(tagType, rdser, fieldType); } @@ -135,17 +169,17 @@ RX_NO_INLINE void CJsonDecoder::throwTagReferenceError(ctag tag, const Payload& tagsMatcher_.tag2name(tag.Name()), pl.Type().Name()); } -RX_NO_INLINE void CJsonDecoder::throwUnexpectedArrayError(const PayloadFieldType& fieldRef) { - throw Error(errLogic, "Error parsing cjson field '%s' - got array, expected scalar %s", fieldRef.Name(), fieldRef.Type().Name()); -} - -template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DummyRecoder, CJsonDecoder::NamelessTagOpt); template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt); -template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DummyRecoder, CJsonDecoder::NamelessTagOpt); + Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); +template bool CJsonDecoder::decodeCJson( + Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::CustomRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt); + Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); +template bool CJsonDecoder::decodeCJson( + Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::CustomRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); } // namespace reindexer diff --git a/cpp_src/core/cjson/cjsondecoder.h b/cpp_src/core/cjson/cjsondecoder.h index 7dc4ed038..d44a706b6 100644 --- a/cpp_src/core/cjson/cjsondecoder.h +++ b/cpp_src/core/cjson/cjsondecoder.h @@ -1,10 +1,11 @@ #pragma once -#include #include "core/cjson/tagspath.h" +#include "core/keyvalue/float_vectors_holder.h" #include "core/payload/fieldsset.h" #include "core/payload/payloadiface.h" #include "core/type_consts.h" +#include "recoder.h" namespace reindexer { @@ -12,16 +13,6 @@ class TagsMatcher; class Serializer; class WrSerializer; -class Recoder { -public: - [[nodiscard]] virtual TagType Type(TagType oldTagType) = 0; - virtual void Recode(Serializer&, WrSerializer&) const = 0; - virtual void Recode(Serializer&, Payload&, int tagName, WrSerializer&) = 0; - [[nodiscard]] virtual bool Match(int field) const noexcept = 0; - [[nodiscard]] virtual bool Match(const TagsPath&) const = 0; - virtual ~Recoder() = default; -}; - class CJsonDecoder { public: using StrHolderT = h_vector; @@ -88,57 +79,64 @@ class CJsonDecoder { bool match_{false}; }; - class DummyRecoder { + class DefaultRecoder { public: - RX_ALWAYS_INLINE DummyRecoder MakeCleanCopy() const noexcept { return DummyRecoder(); } - RX_ALWAYS_INLINE bool Recode(Serializer&, WrSerializer&) const noexcept { return false; } + static RX_ALWAYS_INLINE DefaultRecoder MakeCleanCopy() noexcept { return DefaultRecoder(); } 
+ RX_ALWAYS_INLINE bool Recode(Serializer&, WrSerializer&) const { return false; } RX_ALWAYS_INLINE bool Recode(Serializer&, Payload&, int, WrSerializer&) const noexcept { return false; } - RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, int) const noexcept { return tagType; } - RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, const TagsPath&) const noexcept { return tagType; } + RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, int) const noexcept { + // Do not recode index field + return tagType; + } + RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, const TagsPath&) noexcept { return tagType; } }; - class DefaultRecoder { - public: - DefaultRecoder(Recoder& r) noexcept : r_(&r), needToRecode_(false) {} - RX_ALWAYS_INLINE DefaultRecoder MakeCleanCopy() const noexcept { return DefaultRecoder(*r_); } + class CustomRecoder { + public: + CustomRecoder(Recoder& r) noexcept : r_(&r), needToRecode_(false) {} + RX_ALWAYS_INLINE CustomRecoder MakeCleanCopy() const noexcept { return CustomRecoder(*r_); } RX_ALWAYS_INLINE bool Recode(Serializer& ser, WrSerializer& wser) const { if (needToRecode_) { r_->Recode(ser, wser); + return true; } - return needToRecode_; + return defaultRecoder_.Recode(ser, wser); } RX_ALWAYS_INLINE bool Recode(Serializer& ser, Payload& pl, int tagName, WrSerializer& wser) const { if (needToRecode_) { r_->Recode(ser, pl, tagName, wser); + return true; } - return needToRecode_; + return defaultRecoder_.Recode(ser, pl, tagName, wser); } RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, int field) { needToRecode_ = r_->Match(field); - return needToRecode_ ? r_->Type(tagType) : tagType; + return needToRecode_ ? r_->Type(tagType) : defaultRecoder_.RegisterTagType(tagType, field); } RX_ALWAYS_INLINE TagType RegisterTagType(TagType tagType, const TagsPath& tagsPath) { needToRecode_ = r_->Match(tagsPath); - return needToRecode_ ? r_->Type(tagType) : tagType; + return needToRecode_ ? r_->Type(tagType) : defaultRecoder_.RegisterTagType(tagType, tagsPath); } private: + DefaultRecoder defaultRecoder_; Recoder* r_{nullptr}; bool needToRecode_{false}; }; struct NamedTagOpt {}; struct NamelessTagOpt {}; - template - void Decode(Payload& pl, Serializer& rdSer, WrSerializer& wrSer, FilterT filter = FilterT(), RecoderT recoder = RecoderT()) { + template + void Decode(Payload& pl, Serializer& rdSer, WrSerializer& wrSer, FloatVectorsHolderVector& floatVectorsHolder, + FilterT filter = FilterT(), RecoderT recoder = RecoderT()) { static_assert(std::is_same_v || std::is_same_v, "Other filter types are not allowed for the public API"); - static_assert(std::is_same_v || std::is_same_v, + static_assert(std::is_same_v || std::is_same_v, "Other recoder types are not allowed for the public API"); objectScalarIndexes_.reset(); if rx_likely (!filter.HasArraysFields(pl.Type())) { - decodeCJson(pl, rdSer, wrSer, filter, recoder, NamelessTagOpt{}); + decodeCJson(pl, rdSer, wrSer, filter, recoder, NamelessTagOpt{}, floatVectorsHolder); return; } #ifdef RX_WITH_STDLIB_DEBUG @@ -148,16 +146,16 @@ class CJsonDecoder { // Possible implementation has noticeable negative impact on 'FromCJSONPKOnly' benchmark. 
// Currently, we are using filter for PKs only, and PKs can not be arrays, so this code actually will never be called at the // current moment - decodeCJson(pl, rdSer, wrSer, DummyFilter(), recoder, NamelessTagOpt{}); + decodeCJson(pl, rdSer, wrSer, DummyFilter(), recoder, NamelessTagOpt{}, floatVectorsHolder); #endif // RX_WITH_STDLIB_DEBUG } private: template - bool decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrser, FilterT filter, RecoderT recoder, TagOptT); + bool decodeCJson(Payload& pl, Serializer& rdser, WrSerializer& wrser, FilterT filter, RecoderT recoder, TagOptT, + FloatVectorsHolderVector&); bool isInArray() const noexcept { return arrayLevel_ > 0; } [[noreturn]] void throwTagReferenceError(ctag, const Payload&); - [[noreturn]] void throwUnexpectedArrayError(const PayloadFieldType&); [[nodiscard]] Variant cjsonValueToVariant(TagType tag, Serializer& rdser, KeyValueType dstType); @@ -169,13 +167,17 @@ class CJsonDecoder { StrHolderT& storage_; }; -extern template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DummyRecoder, CJsonDecoder::NamelessTagOpt); extern template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt); -extern template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DummyRecoder, CJsonDecoder::NamelessTagOpt); + Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); +extern template bool CJsonDecoder::decodeCJson( + Payload&, Serializer&, WrSerializer&, CJsonDecoder::DummyFilter, CJsonDecoder::CustomRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); extern template bool CJsonDecoder::decodeCJson( - Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt); + Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::DefaultRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); +extern template bool CJsonDecoder::decodeCJson( + Payload&, Serializer&, WrSerializer&, CJsonDecoder::RestrictingFilter, CJsonDecoder::CustomRecoder, CJsonDecoder::NamelessTagOpt, + FloatVectorsHolderVector&); } // namespace reindexer diff --git a/cpp_src/core/cjson/cjsonmodifier.cc b/cpp_src/core/cjson/cjsonmodifier.cc index e28579831..643435d6f 100644 --- a/cpp_src/core/cjson/cjsonmodifier.cc +++ b/cpp_src/core/cjson/cjsonmodifier.cc @@ -1,6 +1,5 @@ #include "cjsonmodifier.h" #include "cjsontools.h" -#include "core/type_consts_helpers.h" #include "jsondecoder.h" #include "tagsmatcher.h" #include "tools/serializer.h" @@ -12,8 +11,8 @@ const std::string_view kWrongFieldsAmountMsg = "Number of fields for update shou class CJsonModifier::Context { public: Context(const IndexedTagsPath& fieldPath, const VariantArray& v, WrSerializer& ser, std::string_view tuple, FieldModifyMode m, - const Payload* pl = nullptr) - : value(v), wrser(ser), rdser(tuple), mode(m), payload(pl) { + FloatVectorsHolderVector& fvHolder, const Payload* pl = nullptr) + : value(v), wrser(ser), rdser(tuple), mode(m), payload(pl), floatVectorsHolder(fvHolder) { jsonPath.reserve(fieldPath.size()); for (const IndexedPathNode& node : fieldPath) { isForAllItems_ = isForAllItems_ || node.IsForAllItems(); @@ -39,14 +38,15 @@ class 
CJsonModifier::Context { bool updateArrayElements = false; const Payload* payload = nullptr; std::array fieldsArrayOffsets; + FloatVectorsHolderVector& floatVectorsHolder; private: bool isForAllItems_ = false; }; void CJsonModifier::SetFieldValue(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, - const Payload& pl) { - auto ctx = initState(tuple, fieldPath, val, ser, &pl, FieldModifyMode::FieldModeSet); + const Payload& pl, FloatVectorsHolderVector& floatVectorsHolder) { + auto ctx = initState(tuple, fieldPath, val, ser, &pl, FieldModifyMode::FieldModeSet, floatVectorsHolder); updateFieldInTuple(ctx); if (!ctx.fieldUpdated && !ctx.IsForAllItems()) { throw Error(errParams, "[SetFieldValue] Requested field or array's index was not found"); @@ -54,8 +54,8 @@ void CJsonModifier::SetFieldValue(std::string_view tuple, const IndexedTagsPath& } void CJsonModifier::SetObject(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, - const Payload& pl) { - auto ctx = initState(tuple, fieldPath, val, ser, &pl, FieldModifyMode::FieldModeSetJson); + const Payload& pl, FloatVectorsHolderVector& floatVectorsHolder) { + auto ctx = initState(tuple, fieldPath, val, ser, &pl, FieldModifyMode::FieldModeSetJson, floatVectorsHolder); buildCJSON(ctx); if (!ctx.fieldUpdated && !ctx.IsForAllItems()) { throw Error(errParams, "[SetObject] Requested field or array's index was not found"); @@ -63,17 +63,19 @@ void CJsonModifier::SetObject(std::string_view tuple, const IndexedTagsPath& fie } void CJsonModifier::RemoveField(std::string_view tuple, const IndexedTagsPath& fieldPath, WrSerializer& wrser) { - auto ctx = initState(tuple, fieldPath, {}, wrser, nullptr, FieldModeDrop); + thread_local FloatVectorsHolderVector floatVectorsHolder; + auto ctx = initState(tuple, fieldPath, {}, wrser, nullptr, FieldModeDrop, floatVectorsHolder); dropFieldInTuple(ctx); } CJsonModifier::Context CJsonModifier::initState(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, - WrSerializer& ser, const Payload* pl, FieldModifyMode mode) { + WrSerializer& ser, const Payload* pl, FieldModifyMode mode, + FloatVectorsHolderVector& floatVectorsHolder) { if (fieldPath.empty()) { throw Error(errLogic, kWrongFieldsAmountMsg); } tagsPath_.clear(); - Context ctx(fieldPath, val, ser, tuple, mode, pl); + Context ctx(fieldPath, val, ser, tuple, mode, floatVectorsHolder, pl); fieldPath_ = fieldPath; return ctx; @@ -86,14 +88,14 @@ void CJsonModifier::updateObject(Context& ctx, int tagName) const { CJsonBuilder cjsonBuilder(ctx.wrser, ObjType::TypeArray, &tagsMatcher_, tagName); for (const auto& item : ctx.value) { auto objBuilder = cjsonBuilder.Object(nullptr); - jsonDecoder.Decode(std::string_view(item), objBuilder, ctx.jsonPath); + jsonDecoder.Decode(std::string_view(item), objBuilder, ctx.jsonPath, ctx.floatVectorsHolder); } return; } assertrx(ctx.value.size() == 1); CJsonBuilder cjsonBuilder(ctx.wrser, ObjType::TypeObject, &tagsMatcher_, tagName); - jsonDecoder.Decode(std::string_view(ctx.value.front()), cjsonBuilder, ctx.jsonPath); + jsonDecoder.Decode(std::string_view(ctx.value.front()), cjsonBuilder, ctx.jsonPath, ctx.floatVectorsHolder); } void CJsonModifier::insertField(Context& ctx) const { @@ -205,7 +207,7 @@ void CJsonModifier::writeCTag(const ctag& tag, Context& ctx) { const int field = tag.Field(); const int tagName = tag.Name(); if (tagType == TAG_ARRAY) { - const auto count = ctx.rdser.GetVarUint(); + const auto 
count = ctx.rdser.GetVarUInt(); if (!tagMatched || !ctx.fieldUpdated) { auto& lastTag = tagsPath_.back(); for (uint64_t i = 0; i < count; ++i) { @@ -289,6 +291,7 @@ void CJsonModifier::updateArray(TagType atagType, uint32_t count, int tagName, C case TAG_BOOL: case TAG_END: case TAG_UUID: + case TAG_FLOAT: // array tag type updated (need store as object) ctx.wrser.PutCTag(ctag{atagType}); copyCJsonValue(atagType, ctx.rdser, ctx.wrser); @@ -360,6 +363,7 @@ void CJsonModifier::copyArray(int tagName, Context& ctx) { case TAG_BOOL: case TAG_END: case TAG_UUID: + case TAG_FLOAT: copyCJsonValue(atagType, ctx.rdser, ctx.wrser); break; } @@ -444,7 +448,7 @@ bool CJsonModifier::dropFieldInTuple(Context& ctx) { if (isIndexed(field)) { if (tagType == TAG_ARRAY) { - const auto count = ctx.rdser.GetVarUint(); + const auto count = ctx.rdser.GetVarUInt(); ctx.wrser.PutVarUint(count); } return true; @@ -489,6 +493,7 @@ bool CJsonModifier::dropFieldInTuple(Context& ctx) { case TAG_NULL: case TAG_END: case TAG_UUID: + case TAG_FLOAT: copyCJsonValue(atagType, ctx.rdser, ctx.wrser); break; } @@ -546,40 +551,61 @@ bool CJsonModifier::buildCJSON(Context& ctx) { const auto field = tag.Field(); if (tagType == TAG_ARRAY) { - const carraytag atag{isIndexed(field) ? carraytag(ctx.rdser.GetVarUint(), pt_.Field(tag.Field()).Type().ToTagType()) - : ctx.rdser.GetCArrayTag()}; - ctx.wrser.PutCArrayTag(atag); - const auto arrSize = atag.Count(); - for (size_t i = 0; i < arrSize; ++i) { - tagsPath_.back().SetIndex(i); - tagMatched = fieldPath_.Compare(tagsPath_); - if (tagMatched) { - updateObject(ctx, 0); - skipCjsonTag(ctx.rdser.GetCTag(), ctx.rdser, &ctx.fieldsArrayOffsets); - continue; + const bool isIndexedField = isIndexed(field); + bool isFloatVector = isIndexedField && pt_.Field(tag.Field()).Type().Is(); + if (isFloatVector) { + // Embed float vector into CJSON as float array + const carraytag atag(ctx.rdser.GetVarUInt(), TAG_FLOAT); + ctx.wrser.PutCArrayTag(atag); + assertrx_dbg(ctx.fieldsArrayOffsets[field] == 0); + auto value = ctx.payload->Get(field, ctx.fieldsArrayOffsets[field]); + const auto view = ConstFloatVectorView(value); + if rx_unlikely (view.IsStripped()) { + throw Error(errLogic, "CJsonModifier: Attempt to serialize stripped vector into CJSON"); } + const auto span = view.Span(); + assertrx_dbg(span.size() == atag.Count()); + for (float v : span) { + ctx.wrser.PutFloat(v); + } + ctx.fieldsArrayOffsets[field] += 1; + } else { + const carraytag atag{isIndexedField ? 
carraytag(ctx.rdser.GetVarUInt(), pt_.Field(tag.Field()).Type().ToTagType()) + : ctx.rdser.GetCArrayTag()}; + ctx.wrser.PutCArrayTag(atag); + const auto arrSize = atag.Count(); + for (size_t i = 0; i < arrSize; ++i) { + tagsPath_.back().SetIndex(i); + tagMatched = fieldPath_.Compare(tagsPath_); + if (tagMatched) { + updateObject(ctx, 0); + skipCjsonTag(ctx.rdser.GetCTag(), ctx.rdser, &ctx.fieldsArrayOffsets); + continue; + } - switch (atag.Type()) { - case TAG_OBJECT: { - TagsPathScope pathScopeObj(ctx.currObjPath, tagName); - buildCJSON(ctx); - break; + switch (atag.Type()) { + case TAG_OBJECT: { + TagsPathScope pathScopeObj(ctx.currObjPath, tagName); + buildCJSON(ctx); + break; + } + case TAG_VARINT: + case TAG_DOUBLE: + case TAG_STRING: + case TAG_BOOL: + case TAG_ARRAY: + case TAG_NULL: + case TAG_END: + case TAG_UUID: + case TAG_FLOAT: + embedFieldValue(atag.Type(), field, ctx, i); + break; } - case TAG_VARINT: - case TAG_DOUBLE: - case TAG_STRING: - case TAG_BOOL: - case TAG_ARRAY: - case TAG_NULL: - case TAG_END: - case TAG_UUID: - embedFieldValue(atag.Type(), field, ctx, i); - break; } - } - if (isIndexed(field)) { - ctx.fieldsArrayOffsets[field] += arrSize; + if (isIndexed(field)) { + ctx.fieldsArrayOffsets[field] += arrSize; + } } return true; } diff --git a/cpp_src/core/cjson/cjsonmodifier.h b/cpp_src/core/cjson/cjsonmodifier.h index c907b437c..2d7ec0ffc 100644 --- a/cpp_src/core/cjson/cjsonmodifier.h +++ b/cpp_src/core/cjson/cjsonmodifier.h @@ -1,6 +1,7 @@ #pragma once #include "core/cjson/tagspath.h" +#include "core/keyvalue/float_vectors_holder.h" #include "core/payload/payloadiface.h" #include "core/payload/payloadtype.h" @@ -12,14 +13,15 @@ class CJsonModifier { public: CJsonModifier(TagsMatcher& tagsMatcher, PayloadType pt) noexcept : pt_(std::move(pt)), tagsMatcher_(tagsMatcher) {} void SetFieldValue(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, - const Payload& pl); - void SetObject(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, const Payload& pl); + const Payload& pl, FloatVectorsHolderVector&); + void SetObject(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, const Payload& pl, + FloatVectorsHolderVector&); void RemoveField(std::string_view tuple, const IndexedTagsPath& fieldPath, WrSerializer& wrser); private: class Context; Context initState(std::string_view tuple, const IndexedTagsPath& fieldPath, const VariantArray& val, WrSerializer& ser, - const Payload* pl, FieldModifyMode mode); + const Payload* pl, FieldModifyMode mode, FloatVectorsHolderVector&); bool updateFieldInTuple(Context& ctx); bool dropFieldInTuple(Context& ctx); bool buildCJSON(Context& ctx); diff --git a/cpp_src/core/cjson/cjsontools.cc b/cpp_src/core/cjson/cjsontools.cc index f08e4d8b0..f08432c98 100644 --- a/cpp_src/core/cjson/cjsontools.cc +++ b/cpp_src/core/cjson/cjsontools.cc @@ -31,12 +31,12 @@ void copyCJsonValue(TagType tagType, const Variant& value, WrSerializer& wrser) wrser.PutDouble(static_cast(value.convert(KeyValueType::Double{}))); break; case TAG_VARINT: - value.Type().EvaluateOneOf([&](KeyValueType::Int) { wrser.PutVarint(value.As()); }, - [&](KeyValueType::Int64) { wrser.PutVarint(value.As()); }, - [&](OneOf) { - wrser.PutVarint(static_cast(value.convert(KeyValueType::Int64{}))); - }); + value.Type().EvaluateOneOf( + [&](KeyValueType::Int) { wrser.PutVarint(value.As()); }, + [&](KeyValueType::Int64) { 
wrser.PutVarint(value.As()); }, + [&](OneOf) { wrser.PutVarint(static_cast(value.convert(KeyValueType::Int64{}))); }); break; case TAG_BOOL: wrser.PutBool(static_cast(value.convert(KeyValueType::Bool{}))); @@ -47,6 +47,9 @@ void copyCJsonValue(TagType tagType, const Variant& value, WrSerializer& wrser) case TAG_UUID: wrser.PutUuid(value.convert(KeyValueType::Uuid{}).As()); break; + case TAG_FLOAT: + wrser.PutFloat(static_cast(value.convert(KeyValueType::Float{}))); + break; case TAG_NULL: break; case TAG_OBJECT: @@ -110,6 +113,9 @@ void copyCJsonValue(TagType tagType, Serializer& rdser, WrSerializer& wrser) { case TAG_UUID: wrser.PutUuid(rdser.GetUuid()); break; + case TAG_FLOAT: + wrser.PutFloat(rdser.GetFloat()); + break; case TAG_OBJECT: wrser.PutVariant(rdser.GetVariant()); break; @@ -138,7 +144,7 @@ void skipCjsonTag(ctag tag, Serializer& rdser, std::array } } } else { - const auto len = rdser.GetVarUint(); + const auto len = rdser.GetVarUInt(); if (fieldsArrayOffsets) { (*fieldsArrayOffsets)[field] += len; } @@ -155,7 +161,8 @@ void skipCjsonTag(ctag tag, Serializer& rdser, std::array case TAG_END: case TAG_BOOL: case TAG_NULL: - case TAG_UUID: { + case TAG_UUID: + case TAG_FLOAT: { const auto field = tag.Field(); const bool embeddedField = (field < 0); if (embeddedField) { @@ -187,10 +194,14 @@ void buildPayloadTuple(const PayloadIface& pl, const TagsMatcher* tagsMatcher const int tagName = tagsMatcher->name2tag(fieldType.JsonPaths()[0]); assertf(tagName != 0, "ns=%s, field=%s", pl.Type().Name(), fieldType.JsonPaths()[0]); - if (fieldType.IsArray()) { + if (fieldType.IsFloatVector()) { + const auto value = pl.Get(field, 0); + const auto count = ConstFloatVectorView(value).Dimension(); + builder.ArrayRef(tagName, field, int(count)); + } else if (fieldType.IsArray()) { builder.ArrayRef(tagName, field, pl.GetArrayLen(field)); } else { - builder.Ref(tagName, pl.Get(field, 0), field); + builder.Ref(tagName, pl.Get(field, 0).Type(), field); } } } @@ -207,9 +218,52 @@ void throwScalarMultipleEncodesError(const Payload& pl, const PayloadFieldType& throw Error(errLogic, "Non-array field '%s' [%d] from '%s' can only be encoded once.", f.Name(), field, pl.Type().Name()); } +void throwUnexpectedArrayError(std::string_view parserName, const PayloadFieldType& fieldRef) { + throw Error(errLogic, "Error parsing %s field '%s' - got array, expected scalar %s", parserName, fieldRef.Name(), + fieldRef.Type().Name()); +} + +void throwUnexpectedArraySizeForFloatVectorError(std::string_view parserName, const PayloadFieldType& fieldRef, size_t size) { + throw Error(errLogic, "Error parsing %s field '%s' - got array of size %d, expected float_vector of size %d", parserName, + fieldRef.Name(), size, int(fieldRef.FloatVectorDimension())); +} + +void throwUnexpectedArrayTypeForFloatVectorError(std::string_view parserName, const PayloadFieldType& fieldRef) { + throw Error(errLogic, "Error parsing %s field '%s' - got array of non-double values, expected array convertible to %s", parserName, + fieldRef.Name(), fieldRef.Type().Name()); +} + void throwUnexpectedArraySizeError(std::string_view parserName, const PayloadFieldType& f, int arraySize) { throw Error(errParams, "%s array field '%s' for this index type must contain %d elements, but got %d", parserName, f.Name(), - f.ArrayDim(), arraySize); + f.ArrayDims(), arraySize); +} + +static void skipCjsonValue(TagType type, Serializer& cjson) { + switch (type) { + case TAG_VARINT: + cjson.GetVarint(); + break; + case TAG_DOUBLE: + cjson.GetDouble(); + break; + case 
TAG_STRING: + cjson.GetPVString(); + break; + case TAG_BOOL: + cjson.GetVarUInt(); + break; + case TAG_UUID: + cjson.GetUuid(); + break; + case TAG_FLOAT: + cjson.GetFloat(); + break; + case TAG_NULL: + case TAG_OBJECT: + case TAG_ARRAY: + case TAG_END: + assertrx(0); + } } static void dumpCjsonValue(TagType type, Serializer& cjson, std::ostream& dump) { @@ -224,11 +278,14 @@ static void dumpCjsonValue(TagType type, Serializer& cjson, std::ostream& dump) dump << '"' << std::string_view{cjson.GetPVString()} << '"'; break; case TAG_BOOL: - dump << std::boolalpha << bool(cjson.GetVarUint()); + dump << std::boolalpha << bool(cjson.GetVarUInt()); break; case TAG_UUID: dump << std::string{cjson.GetUuid()}; break; + case TAG_FLOAT: + dump << cjson.GetFloat(); + break; case TAG_NULL: case TAG_OBJECT: case TAG_ARRAY: @@ -240,6 +297,7 @@ static void dumpCjsonValue(TagType type, Serializer& cjson, std::ostream& dump) template static void dumpCjsonObject(Serializer& cjson, std::ostream& dump, const TagsMatcher* tm, const PL* pl, std::string_view tab, unsigned indentLevel) { + static constexpr size_t kMaxArrayOutput = 3; const auto indent = [&dump, tab](unsigned indLvl) { for (unsigned i = 0; i < indLvl; ++i) { dump << tab; @@ -266,6 +324,7 @@ static void dumpCjsonObject(Serializer& cjson, std::ostream& dump, const TagsMat switch (type) { case TAG_VARINT: case TAG_DOUBLE: + case TAG_FLOAT: case TAG_STRING: case TAG_BOOL: case TAG_UUID: @@ -279,20 +338,50 @@ static void dumpCjsonObject(Serializer& cjson, std::ostream& dump, const TagsMat case TAG_ARRAY: { dump << '\n'; indent(indentLevel + 1); - const size_t count = cjson.GetVarUint(); + const size_t count = cjson.GetVarUInt(); dump << "Count: " << count; if (pl) { dump << " -> ["; buf.clear(); pl->Get(field, buf); - assertrx(buf.size() == count); - for (size_t i = 0; i < count; ++i) { - if (i != 0) { - dump << ", "; + if (pl->Type().Field(field).IsFloatVector()) { + assertrx(buf.size() == 1); + const ConstFloatVectorView vect{buf[0]}; + if (vect.IsEmpty()) { + dump << " -> "; + } else { + dump << " -> " << size_t(vect.Dimension()); + if (vect.IsStripped()) { + dump << "[]"; + } else { + dump << '['; + for (size_t i = 0; i < std::min(size_t(vect.Dimension()), kMaxArrayOutput); ++i) { + if (i != 0) { + dump << ", "; + } + dump << vect.Data()[i]; + } + if (size_t(vect.Dimension()) > kMaxArrayOutput) { + dump << ", ...]"; + } else { + dump << ']'; + } + } + } + } else { + assertrx(buf.size() == count); + for (size_t i = 0; i < std::min(count, kMaxArrayOutput); ++i) { + if (i != 0) { + dump << ", "; + } + dump << buf[i].As(); + } + if (count > kMaxArrayOutput) { + dump << " ...]"; + } else { + dump << ']'; } - dump << buf[i].As(); } - dump << ']'; } } break; case TAG_NULL: @@ -309,6 +398,7 @@ static void dumpCjsonObject(Serializer& cjson, std::ostream& dump, const TagsMat case TAG_STRING: case TAG_BOOL: case TAG_UUID: + case TAG_FLOAT: indent(indentLevel + 1); dumpCjsonValue(type, cjson, dump); dump << '\n'; @@ -352,12 +442,19 @@ static void dumpCjsonObject(Serializer& cjson, std::ostream& dump, const TagsMat } indent(indentLevel + 2); } else { - for (size_t i = 0; i < count; ++i) { + size_t i = 0; + for (; i < std::min(count, kMaxArrayOutput); ++i) { if (i != 0) { dump << ", "; } dumpCjsonValue(arr.Type(), cjson, dump); } + for (; i < count; ++i) { + skipCjsonValue(arr.Type(), cjson); + } + if (count > kMaxArrayOutput) { + dump << ", ..."; + } } dump << "]\n"; } diff --git a/cpp_src/core/cjson/cjsontools.h b/cpp_src/core/cjson/cjsontools.h index 
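
The dumpCjsonObject changes above cap array output at kMaxArrayOutput elements but still read the remaining items (via skipCjsonValue) so the read position in the CJSON stream stays correct. The same idea in a self-contained form, with a plain container standing in for the serializer:

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    constexpr size_t kMaxArrayOutput = 3;

    void dumpTruncated(const std::vector<double>& arr, std::ostream& out) {
        size_t i = 0;
        out << '[';
        for (; i < std::min(arr.size(), kMaxArrayOutput); ++i) {
            if (i != 0) {
                out << ", ";
            }
            out << arr[i];  // printed items
        }
        for (; i < arr.size(); ++i) {
            (void)arr[i];  // items that must still be consumed, just not printed
        }
        out << (arr.size() > kMaxArrayOutput ? ", ...]" : "]");
    }
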
b0a10dec2..ee9d6b342 100644 --- a/cpp_src/core/cjson/cjsontools.h +++ b/cpp_src/core/cjson/cjsontools.h @@ -16,6 +16,9 @@ void putCJsonValue(TagType tagType, int tagName, const VariantArray& values, WrS void skipCjsonTag(ctag tag, Serializer& rdser, std::array* fieldsArrayOffsets = nullptr); [[nodiscard]] Variant cjsonValueToVariant(TagType tag, Serializer& rdser, KeyValueType dstType); +[[noreturn]] void throwUnexpectedArrayError(std::string_view parserName, const PayloadFieldType&); +[[noreturn]] void throwUnexpectedArraySizeForFloatVectorError(std::string_view parserName, const PayloadFieldType& fieldRef, size_t size); +[[noreturn]] void throwUnexpectedArrayTypeForFloatVectorError(std::string_view parserName, const PayloadFieldType& fieldRef); [[noreturn]] void throwUnexpectedNestedArrayError(std::string_view parserName, const PayloadFieldType& f); [[noreturn]] void throwScalarMultipleEncodesError(const Payload& pl, const PayloadFieldType& f, int field); [[noreturn]] void throwUnexpectedArraySizeError(std::string_view parserName, const PayloadFieldType& f, int arraySize); @@ -33,7 +36,7 @@ RX_ALWAYS_INLINE void validateNonArrayFieldRestrictions(const ScalarIndexesSetT& RX_ALWAYS_INLINE void validateArrayFieldRestrictions(const PayloadFieldType& f, int arraySize, std::string_view parserName) { if (f.IsArray()) { - if rx_unlikely (arraySize && f.ArrayDim() > 0 && f.ArrayDim() != arraySize) { + if rx_unlikely (arraySize && f.ArrayDims() > 0 && int(f.ArrayDims()) != arraySize) { throwUnexpectedArraySizeError(parserName, f, arraySize); } } diff --git a/cpp_src/core/cjson/csvbuilder.cc b/cpp_src/core/cjson/csvbuilder.cc index bd299c2b4..d37a450b8 100644 --- a/cpp_src/core/cjson/csvbuilder.cc +++ b/cpp_src/core/cjson/csvbuilder.cc @@ -200,7 +200,7 @@ CsvBuilder& CsvBuilder::Null(std::string_view name) { CsvBuilder& CsvBuilder::Put(std::string_view name, const Variant& kv, int offset) { kv.Type().EvaluateOneOf( [&](KeyValueType::Int) { Put(name, int(kv), offset); }, [&](KeyValueType::Int64) { Put(name, int64_t(kv), offset); }, - [&](KeyValueType::Double) { Put(name, double(kv), offset); }, + [&](KeyValueType::Double) { Put(name, double(kv), offset); }, [&](KeyValueType::Float) { Put(name, float(kv), offset); }, [&](KeyValueType::String) { Put(name, std::string_view(kv), offset); }, [&](KeyValueType::Null) { Null(name); }, [&](KeyValueType::Bool) { Put(name, bool(kv), offset); }, [&](KeyValueType::Tuple) { @@ -209,7 +209,8 @@ CsvBuilder& CsvBuilder::Put(std::string_view name, const Variant& kv, int offset arrNode.Put({nullptr, 0}, val); } }, - [&](KeyValueType::Uuid) { Put(name, Uuid{kv}, offset); }, [](OneOf) noexcept {}); + [&](KeyValueType::Uuid) { Put(name, Uuid{kv}, offset); }, + [](OneOf) noexcept { assertrx_throw(false); }); return *this; } diff --git a/cpp_src/core/cjson/csvbuilder.h b/cpp_src/core/cjson/csvbuilder.h index b51d70b8a..4105a6d91 100644 --- a/cpp_src/core/cjson/csvbuilder.h +++ b/cpp_src/core/cjson/csvbuilder.h @@ -1,7 +1,7 @@ #pragma once #include -#include "estl/span.h" +#include #include "objtype.h" #include "tagslengths.h" #include "tagsmatcher.h" @@ -51,14 +51,14 @@ class CsvBuilder { CsvBuilder Array(int tagName, int size = KUnknownFieldSize) { return Array(getNameByTag(tagName), size); } template - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { CsvBuilder node = Array(tagName); for (const auto& d : data) { node.Put({}, d); } } template - void Array(std::string_view n, span data, int /*offset*/ 
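
csvbuilder.h is one of several headers in this patch that swap estl/span.h for the standard <span>. std::span is a non-owning view over contiguous memory, so the Array(...) templates keep their shape and most call sites compile unchanged. A small usage sketch:

    #include <iostream>
    #include <span>
    #include <vector>

    template <typename T>
    void printAll(std::span<const T> data) {
        for (const auto& d : data) {
            std::cout << d << '\n';
        }
    }

    int main() {
        std::vector<int> v{1, 2, 3};
        int c[] = {4, 5};
        printAll<int>(v);  // vectors, arrays and pointer+size all convert
        printAll<int>(c);
    }
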
= 0) { + void Array(std::string_view n, std::span data, int /*offset*/ = 0) { CsvBuilder node = Array(n); for (const auto& d : data) { node.Put({}, d); diff --git a/cpp_src/core/cjson/fieldextractor.h b/cpp_src/core/cjson/fieldextractor.h index 5112b3110..e9f42bc57 100644 --- a/cpp_src/core/cjson/fieldextractor.h +++ b/cpp_src/core/cjson/fieldextractor.h @@ -1,7 +1,7 @@ #pragma once +#include #include "core/payload/fieldsset.h" -#include "estl/span.h" #include "tagsmatcher.h" namespace reindexer { @@ -29,18 +29,18 @@ class FieldsExtractor { FieldsExtractor Object(int) noexcept { return FieldsExtractor(values_, expectedType_, expectedPathDepth_ - 1, filter_, params_); } FieldsExtractor Array(int) noexcept { assertrx_throw(values_); - return FieldsExtractor(&values_->MarkArray(), expectedType_, expectedPathDepth_ - 1, filter_, params_); + return FieldsExtractor(&values_->MarkArray(), expectedType_, expectedPathDepth_, filter_, params_); } FieldsExtractor Object(std::string_view) noexcept { return FieldsExtractor(values_, expectedType_, expectedPathDepth_ - 1, filter_, params_); } FieldsExtractor Object(std::nullptr_t) noexcept { return Object(std::string_view{}); } FieldsExtractor Array(std::string_view) noexcept { - return FieldsExtractor(values_, expectedType_, expectedPathDepth_ - 1, filter_, params_); + return FieldsExtractor(&values_->MarkArray(), expectedType_, expectedPathDepth_, filter_, params_); } template - void Array(int, span data, int offset) { + void Array(int, std::span data, int offset) { const IndexedPathNode& pathNode = getArrayPathNode(); const PathType ptype = pathNotToType(pathNode); if (ptype == PathType::Other) { @@ -156,8 +156,10 @@ class FieldsExtractor { return *this; } expectedType_.EvaluateOneOf( - [&](OneOf) { arg.convert(expectedType_); }, + [&](OneOf) { + arg.convert(expectedType_); + }, [](OneOf) noexcept {}); assertrx_throw(values_); values_->emplace_back(std::move(arg)); diff --git a/cpp_src/core/cjson/jschemachecker.cc b/cpp_src/core/cjson/jschemachecker.cc index 46a904e6b..431490346 100644 --- a/cpp_src/core/cjson/jschemachecker.cc +++ b/cpp_src/core/cjson/jschemachecker.cc @@ -1,5 +1,4 @@ #include "jschemachecker.h" -#include #include #include #include "core/formatters/jsonstring_fmt.h" @@ -9,7 +8,7 @@ namespace reindexer { -JsonSchemaChecker::JsonSchemaChecker(const std::string& json, std::string rootTypeName) : rootTypeName_(std::move(rootTypeName)) { +JsonSchemaChecker::JsonSchemaChecker(std::string_view json, std::string rootTypeName) : rootTypeName_(std::move(rootTypeName)) { Error err = createTypeTable(json); if (!err.ok()) { throw err; @@ -17,7 +16,7 @@ JsonSchemaChecker::JsonSchemaChecker(const std::string& json, std::string rootTy isInit = true; } -Error JsonSchemaChecker::Init(const std::string& json, std::string rootTypeName) { +Error JsonSchemaChecker::Init(std::string_view json, std::string rootTypeName) { if (isInit) { return Error(errLogic, "JsonSchemaChecker already initialized."); } @@ -80,8 +79,8 @@ void JsonSchemaChecker::addSimpleType(std::string tpName) { indexes_.emplace(std::move(tpName), typesTable_.size() - 1); } -Error JsonSchemaChecker::createTypeTable(const std::string& json) { - auto err = schema_.FromJSON(std::string_view(json)); +Error JsonSchemaChecker::createTypeTable(std::string_view json) { + auto err = schema_.FromJSON(json); if (!err.ok()) { return err; } @@ -112,7 +111,7 @@ Error JsonSchemaChecker::createTypeTable(const std::string& json) { } Error JsonSchemaChecker::Check(gason::JsonNode node) { - if 
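
The jschemachecker changes below reference gason tags as gason::JsonTag::OBJECT and so on, which looks like a move from an unscoped enum to an enum class in the vendored gason. The language-level difference, in isolation:

    // Before: unscoped enumerators leak into the enclosing namespace and
    // convert to int implicitly.
    enum JsonTagOld { JSON_OBJECT_OLD, JSON_ARRAY_OLD };

    // After: enum class values need qualification and no longer convert
    // implicitly, so every comparison site is touched once, as in this patch.
    enum class JsonTag { OBJECT, ARRAY };

    bool isObjectOld(JsonTagOld t) { return t == JSON_OBJECT_OLD; }
    bool isObject(JsonTag t) { return t == JsonTag::OBJECT; }
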
(node.value.getTag() != gason::JSON_OBJECT) { + if (node.value.getTag() != gason::JsonTag::OBJECT) { return Error(errParseJson, "Node [%s] should JSON_OBJECT.", node.key); } @@ -155,7 +154,7 @@ Error JsonSchemaChecker::checkScheme(const gason::JsonNode& node, int typeIndex, if (!err.ok()) { return err; } - if (elem.value.getTag() == gason::JSON_OBJECT) { + if (elem.value.getTag() == gason::JsonTag::OBJECT) { if (descr.subElementsTable[subElemIndex->second].second.typeName != "any") { err = checkScheme(elem, descr.subElementsTable[subElemIndex->second].second.typeIndex, path, descr.subElementsTable[subElemIndex->second].first); @@ -163,13 +162,13 @@ Error JsonSchemaChecker::checkScheme(const gason::JsonNode& node, int typeIndex, return err; } } - } else if (elem.value.getTag() == gason::JSON_ARRAY) { + } else if (elem.value.getTag() == gason::JsonTag::ARRAY) { if (descr.subElementsTable[subElemIndex->second].second.typeName != "any") { if (!descr.subElementsTable[subElemIndex->second].second.array) { return Error(errParseJson, "Element [%s] should array in [%s].", elem.key, path); } for (const auto& entry : elem.value) { - if (entry.value.getTag() == gason::JSON_ARRAY || entry.value.getTag() == gason::JSON_OBJECT) { + if (entry.value.getTag() == gason::JsonTag::ARRAY || entry.value.getTag() == gason::JsonTag::OBJECT) { err = checkScheme(entry, descr.subElementsTable[subElemIndex->second].second.typeIndex, path, descr.subElementsTable[subElemIndex->second].first); if (!err.ok()) { diff --git a/cpp_src/core/cjson/jschemachecker.h b/cpp_src/core/cjson/jschemachecker.h index f0c1abc19..b2275773a 100644 --- a/cpp_src/core/cjson/jschemachecker.h +++ b/cpp_src/core/cjson/jschemachecker.h @@ -11,9 +11,9 @@ namespace reindexer { class JsonSchemaChecker { public: - explicit JsonSchemaChecker(const std::string& json, std::string rootTypeName); + explicit JsonSchemaChecker(std::string_view json, std::string rootTypeName); JsonSchemaChecker() {}; - Error Init(const std::string& json, std::string rootTypeName); + Error Init(std::string_view json, std::string rootTypeName); Error Check(gason::JsonNode node); private: @@ -45,7 +45,7 @@ class JsonSchemaChecker { Error checkScheme(const gason::JsonNode& node, int typeIndex, std::string& path, const std::string& elementName); std::string createType(const PrefixTree::PrefixTreeNode* node, const std::string& typeName = ""); - Error createTypeTable(const std::string& json); + Error createTypeTable(std::string_view json); static bool isSimpleType(std::string_view tp); void addSimpleType(std::string tpName); Error checkExists(std::string_view name, ValAppearance* element, const std::string& path); diff --git a/cpp_src/core/cjson/jsonbuilder.cc b/cpp_src/core/cjson/jsonbuilder.cc index e93c05d05..75779ad42 100644 --- a/cpp_src/core/cjson/jsonbuilder.cc +++ b/cpp_src/core/cjson/jsonbuilder.cc @@ -83,7 +83,7 @@ JsonBuilder& JsonBuilder::Null(std::string_view name) { JsonBuilder& JsonBuilder::Put(std::string_view name, const Variant& kv, int offset) { kv.Type().EvaluateOneOf( [&](KeyValueType::Int) { Put(name, int(kv), offset); }, [&](KeyValueType::Int64) { Put(name, int64_t(kv), offset); }, - [&](KeyValueType::Double) { Put(name, double(kv), offset); }, + [&](KeyValueType::Double) { Put(name, double(kv), offset); }, [&](KeyValueType::Float) { Put(name, float(kv), offset); }, [&](KeyValueType::String) { Put(name, std::string_view(kv), offset); }, [&](KeyValueType::Null) { Null(name); }, [&](KeyValueType::Bool) { Put(name, bool(kv), offset); }, [&](KeyValueType::Tuple) 
{ @@ -92,7 +92,8 @@ JsonBuilder& JsonBuilder::Put(std::string_view name, const Variant& kv, int offs arrNode.Put({nullptr, 0}, val, offset); } }, - [&](KeyValueType::Uuid) { Put(name, Uuid{kv}, offset); }, [](OneOf) noexcept {}); + [&](KeyValueType::Uuid) { Put(name, Uuid{kv}, offset); }, + [](OneOf) noexcept {}); return *this; } diff --git a/cpp_src/core/cjson/jsonbuilder.h b/cpp_src/core/cjson/jsonbuilder.h index 9934a99c8..211546a3e 100644 --- a/cpp_src/core/cjson/jsonbuilder.h +++ b/cpp_src/core/cjson/jsonbuilder.h @@ -1,6 +1,6 @@ #pragma once -#include "estl/span.h" +#include #include "objtype.h" #include "tagslengths.h" #include "tagsmatcher.h" @@ -31,14 +31,14 @@ class JsonBuilder { JsonBuilder Array(int tagName, int size = KUnknownFieldSize) { return Array(getNameByTag(tagName), size); } template - void Array(int tagName, span data, int /*offset*/ = 0) { + void Array(int tagName, std::span data, int /*offset*/ = 0) { JsonBuilder node = Array(tagName); for (const auto& d : data) { node.Put({}, d); } } template - void Array(std::string_view n, span data, int /*offset*/ = 0) { + void Array(std::string_view n, std::span data, int /*offset*/ = 0) { JsonBuilder node = Array(n); for (const auto& d : data) { node.Put({}, d); diff --git a/cpp_src/core/cjson/jsondecoder.cc b/cpp_src/core/cjson/jsondecoder.cc index 2a0e2eba8..8aeb3018b 100644 --- a/cpp_src/core/cjson/jsondecoder.cc +++ b/cpp_src/core/cjson/jsondecoder.cc @@ -9,12 +9,12 @@ namespace reindexer { -Error JsonDecoder::Decode(Payload& pl, WrSerializer& wrser, const gason::JsonValue& v) { +Error JsonDecoder::Decode(Payload& pl, WrSerializer& wrser, const gason::JsonValue& v, FloatVectorsHolderVector& floatVectorsHolder) { try { objectScalarIndexes_.reset(); tagsPath_.clear(); CJsonBuilder builder(wrser, ObjType::TypePlain, &tagsMatcher_); - decodeJson(&pl, builder, v, 0, true); + decodeJson(&pl, builder, v, 0, floatVectorsHolder, true); } catch (const Error& err) { @@ -23,7 +23,9 @@ Error JsonDecoder::Decode(Payload& pl, WrSerializer& wrser, const gason::JsonVal return {}; } -void JsonDecoder::decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gason::JsonValue& v, bool match) { +void JsonDecoder::decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gason::JsonValue& v, + FloatVectorsHolderVector& floatVectorsHolder, bool match) { + using namespace std::string_view_literals; for (const auto& elem : v) { int tagName = tagsMatcher_.name2tag(elem.key, true); assertrx(tagName); @@ -38,46 +40,57 @@ void JsonDecoder::decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gas } if (field < 0) { - decodeJson(&pl, builder, elem.value, tagName, match); + decodeJson(&pl, builder, elem.value, tagName, floatVectorsHolder, match); } else if (match) { // Indexed field. 
extract it const auto& f = pl.Type().Field(field); switch (elem.value.getTag()) { - case gason::JSON_ARRAY: { - if rx_unlikely (!f.IsArray()) { - throw Error(errLogic, "Error parsing json field '%s' - got array, expected scalar %s", f.Name(), f.Type().Name()); - } - int count = 0; - for (auto& subelem : elem.value) { - (void)subelem; - ++count; - } - validateArrayFieldRestrictions(f, count, "json"); - int pos = pl.ResizeArray(field, count, true); - for (auto& subelem : elem.value) { - pl.Set(field, pos++, jsonValue2Variant(subelem.value, f.Type(), f.Name())); - } - builder.ArrayRef(tagName, field, count); - } break; - case gason::JSON_NULL: + case gason::JsonTag::JSON_NULL: validateNonArrayFieldRestrictions(objectScalarIndexes_, pl, f, field, isInArray(), "json"); objectScalarIndexes_.set(field); builder.Null(tagName); break; - case gason::JSON_NUMBER: - case gason::JSON_DOUBLE: - case gason::JSON_OBJECT: - case gason::JSON_STRING: - case gason::JSON_TRUE: - case gason::JSON_FALSE: { + case gason::JsonTag::ARRAY: + if (f.Type().Is()) { + validateNonArrayFieldRestrictions(objectScalarIndexes_, pl, f, field, isInArray(), "json"); + validateArrayFieldRestrictions(f, 1, "json"); + objectScalarIndexes_.set(field); + Variant value = jsonValue2Variant(elem.value, f.Type(), f.Name(), &floatVectorsHolder); + assertrx_dbg(value.Type().Is()); + const auto count = ConstFloatVectorView(value).Dimension(); + pl.Set(field, std::move(value)); + builder.ArrayRef(tagName, field, int(count)); + } else { + if rx_unlikely (!f.IsArray()) { + throwUnexpectedArrayError("json"sv, f); + } + int count = 0; + for (auto& subelem : elem.value) { + (void)subelem; + ++count; + } + validateArrayFieldRestrictions(f, count, "json"); + int pos = pl.ResizeArray(field, count, true); + for (auto& subelem : elem.value) { + pl.Set(field, pos++, jsonValue2Variant(subelem.value, f.Type(), f.Name(), nullptr)); + } + builder.ArrayRef(tagName, field, count); + } + break; + case gason::JsonTag::NUMBER: + case gason::JsonTag::DOUBLE: + case gason::JsonTag::OBJECT: + case gason::JsonTag::STRING: + case gason::JsonTag::JTRUE: + case gason::JsonTag::JFALSE: { validateNonArrayFieldRestrictions(objectScalarIndexes_, pl, f, field, isInArray(), "json"); validateArrayFieldRestrictions(f, 1, "json"); objectScalarIndexes_.set(field); - Variant value = jsonValue2Variant(elem.value, f.Type(), f.Name()); - builder.Ref(tagName, value, field); + Variant value = jsonValue2Variant(elem.value, f.Type(), f.Name(), nullptr); + builder.Ref(tagName, value.Type(), field); pl.Set(field, std::move(value), true); } break; - case gason::JSON_EMPTY: + case gason::JsonTag::EMPTY: default: throw Error(errLogic, "Unexpected '%d' tag", elem.value.getTag()); } @@ -91,51 +104,52 @@ void JsonDecoder::decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gas // Split original JSON into 2 parts: // 1. PayloadFields - fields from json found by 'jsonPath' tags // 2. 
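
The jsondecoder hunk above is where a JSON array lands in a float_vector index: the value is converted as a whole, stored into the payload, and encoded as an ArrayRef whose count is the vector dimension. The dimension check that the parsers apply (see throwUnexpectedArraySizeForFloatVectorError earlier in this patch), in a simplified self-contained form with illustrative names:

    #include <cstddef>
    #include <stdexcept>
    #include <string>
    #include <vector>

    std::vector<float> toFloatVector(const std::vector<double>& parsed, size_t dimension) {
        if (parsed.size() != dimension) {
            throw std::logic_error("got array of size " + std::to_string(parsed.size()) +
                                   ", expected float_vector of size " + std::to_string(dimension));
        }
        std::vector<float> vect(dimension);
        for (size_t i = 0; i < dimension; ++i) {
            vect[i] = static_cast<float>(parsed[i]);
        }
        return vect;
    }
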
stripped binary packed JSON without fields values found by 'jsonPath' tags -void JsonDecoder::decodeJson(Payload* pl, CJsonBuilder& builder, const gason::JsonValue& v, int tagName, bool match) { +void JsonDecoder::decodeJson(Payload* pl, CJsonBuilder& builder, const gason::JsonValue& v, int tagName, + FloatVectorsHolderVector& floatVectorsHolder, bool match) { auto jsonTag = v.getTag(); - if (!match && jsonTag != gason::JSON_OBJECT) { + if (!match && jsonTag != gason::JsonTag::OBJECT) { return; } switch (jsonTag) { - case gason::JSON_NUMBER: { + case gason::JsonTag::NUMBER: { int64_t value = v.toNumber(); builder.Put(tagName, int64_t(value)); } break; - case gason::JSON_DOUBLE: { + case gason::JsonTag::DOUBLE: { double value = v.toDouble(); builder.Put(tagName, value); } break; - case gason::JSON_STRING: + case gason::JsonTag::STRING: builder.Put(tagName, v.toString()); break; - case gason::JSON_TRUE: + case gason::JsonTag::JTRUE: builder.Put(tagName, true); break; - case gason::JSON_FALSE: + case gason::JsonTag::JFALSE: builder.Put(tagName, false); break; - case gason::JSON_NULL: + case gason::JsonTag::JSON_NULL: builder.Null(tagName); break; - case gason::JSON_ARRAY: { + case gason::JsonTag::ARRAY: { CounterGuardIR32 g(arrayLevel_); ObjType type = (gason::isHomogeneousArray(v)) ? ObjType::TypeArray : ObjType::TypeObjectArray; auto arrNode = builder.Array(tagName, type); for (const auto& elem : v) { - decodeJson(pl, arrNode, elem.value, 0, match); + decodeJson(pl, arrNode, elem.value, 0, floatVectorsHolder, match); } break; } - case gason::JSON_OBJECT: { + case gason::JsonTag::OBJECT: { auto objNode = builder.Object(tagName); if (pl) { - decodeJsonObject(*pl, objNode, v, match); + decodeJsonObject(*pl, objNode, v, floatVectorsHolder, match); } else { - decodeJsonObject(v, objNode); + decodeJsonObject(v, objNode, floatVectorsHolder); } break; } - case gason::JSON_EMPTY: + case gason::JsonTag::EMPTY: default: throw Error(errLogic, "Unexpected '%d' tag", jsonTag); } @@ -150,24 +164,25 @@ class TagsPathGuard { TagsPath& tagsPath_; }; -void JsonDecoder::decodeJsonObject(const gason::JsonValue& root, CJsonBuilder& builder) { +void JsonDecoder::decodeJsonObject(const gason::JsonValue& root, CJsonBuilder& builder, FloatVectorsHolderVector& floatVectorsHolder) { for (const auto& elem : root) { const int tagName = tagsMatcher_.name2tag(elem.key, true); if (tagName == 0) { throw Error(errParseJson, "Unsupported JSON format. 
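
decodeJson relies on CounterGuardIR32 to track the arrayLevel_ nesting counter (isInArray() simply checks that it is greater than zero), and the next hunk defines a similar TagsPathGuard for the tags path. The guard implementations are not part of this patch; the usual shape of such a scope guard is:

    #include <cstdint>

    // Increment on entry, decrement on every exit path, including exceptions.
    class CounterGuard {
    public:
        explicit CounterGuard(int32_t& counter) noexcept : counter_(counter) { ++counter_; }
        ~CounterGuard() { --counter_; }
        CounterGuard(const CounterGuard&) = delete;
        CounterGuard& operator=(const CounterGuard&) = delete;

    private:
        int32_t& counter_;
    };
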
Unnamed field detected"); } TagsPathGuard tagsPathGuard(tagsPath_, tagName); - decodeJson(nullptr, builder, elem.value, tagName, true); + decodeJson(nullptr, builder, elem.value, tagName, floatVectorsHolder, true); } } -void JsonDecoder::Decode(std::string_view json, CJsonBuilder& builder, const TagsPath& fieldPath) { +void JsonDecoder::Decode(std::string_view json, CJsonBuilder& builder, const TagsPath& fieldPath, + FloatVectorsHolderVector& floatVectorsHolder) { try { objectScalarIndexes_.reset(); tagsPath_ = fieldPath; gason::JsonParser jsonParser; gason::JsonNode root = jsonParser.Parse(json); - decodeJsonObject(root.value, builder); + decodeJsonObject(root.value, builder, floatVectorsHolder); } catch (gason::Exception& e) { throw Error(errParseJson, "JSONDecoder: %s", e.what()); } diff --git a/cpp_src/core/cjson/jsondecoder.h b/cpp_src/core/cjson/jsondecoder.h index 066f6740a..39dad6200 100644 --- a/cpp_src/core/cjson/jsondecoder.h +++ b/cpp_src/core/cjson/jsondecoder.h @@ -1,6 +1,7 @@ #pragma once #include "cjsonbuilder.h" +#include "core/keyvalue/float_vectors_holder.h" #include "core/payload/payloadiface.h" #include "gason/gason.h" @@ -10,13 +11,13 @@ class JsonDecoder { public: explicit JsonDecoder(TagsMatcher& tagsMatcher, const FieldsSet* filter = nullptr) noexcept : tagsMatcher_(tagsMatcher), filter_(filter) {} - Error Decode(Payload& pl, WrSerializer& wrSer, const gason::JsonValue& v); - void Decode(std::string_view json, CJsonBuilder& builder, const TagsPath& fieldPath); + Error Decode(Payload& pl, WrSerializer& wrSer, const gason::JsonValue& v, FloatVectorsHolderVector&); + void Decode(std::string_view json, CJsonBuilder& builder, const TagsPath& fieldPath, FloatVectorsHolderVector&); private: - void decodeJsonObject(const gason::JsonValue& root, CJsonBuilder& builder); - void decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gason::JsonValue& v, bool match); - void decodeJson(Payload* pl, CJsonBuilder& builder, const gason::JsonValue& v, int tag, bool match); + void decodeJsonObject(const gason::JsonValue& root, CJsonBuilder& builder, FloatVectorsHolderVector&); + void decodeJsonObject(Payload& pl, CJsonBuilder& builder, const gason::JsonValue& v, FloatVectorsHolderVector&, bool match); + void decodeJson(Payload* pl, CJsonBuilder& builder, const gason::JsonValue& v, int tag, FloatVectorsHolderVector&, bool match); bool isInArray() const noexcept { return arrayLevel_ > 0; } TagsMatcher& tagsMatcher_; diff --git a/cpp_src/core/cjson/msgpackbuilder.cc b/cpp_src/core/cjson/msgpackbuilder.cc index 149754ad0..f3c08ac42 100644 --- a/cpp_src/core/cjson/msgpackbuilder.cc +++ b/cpp_src/core/cjson/msgpackbuilder.cc @@ -96,6 +96,9 @@ void MsgPackBuilder::packCJsonValue(TagType tagType, Serializer& rdser) { case TAG_UUID: packValue(rdser.GetUuid()); break; + case TAG_FLOAT: + packValue(rdser.GetFloat()); + break; case TAG_NULL: packNil(); break; @@ -109,26 +112,25 @@ void MsgPackBuilder::packCJsonValue(TagType tagType, Serializer& rdser) { void MsgPackBuilder::appendJsonObject(std::string_view name, const gason::JsonNode& obj) { auto type = obj.value.getTag(); switch (type) { - case gason::JSON_STRING: { + case gason::JsonTag::STRING: { Put(name, obj.As(), 0); break; } - case gason::JSON_NUMBER: { + case gason::JsonTag::NUMBER: Put(name, obj.As(), 0); break; - } - case gason::JSON_DOUBLE: { + case gason::JsonTag::DOUBLE: { Put(name, obj.As(), 0); break; } - case gason::JSON_OBJECT: - case gason::JSON_ARRAY: { + case gason::JsonTag::OBJECT: + case gason::JsonTag::ARRAY: { int 
size = 0; for (const auto& node : obj) { (void)node; ++size; } - if (type == gason::JSON_OBJECT) { + if (type == gason::JsonTag::OBJECT) { auto pack = Object(name, size); for (const auto& node : obj) { pack.appendJsonObject(std::string_view(node.key), node); @@ -141,19 +143,19 @@ void MsgPackBuilder::appendJsonObject(std::string_view name, const gason::JsonNo } break; } - case gason::JSON_TRUE: { + case gason::JsonTag::JTRUE: { Put(std::string_view(obj.key), true, 0); break; } - case gason::JSON_FALSE: { + case gason::JsonTag::JFALSE: { Put(std::string_view(obj.key), false, 0); break; } - case gason::JSON_NULL: { + case gason::JsonTag::JSON_NULL: { Null(std::string_view(obj.key)); break; } - case gason::JSON_EMPTY: + case gason::JsonTag::EMPTY: default: throw(Error(errLogic, "Unexpected json tag for Object: %d", int(obj.value.getTag()))); } diff --git a/cpp_src/core/cjson/msgpackbuilder.h b/cpp_src/core/cjson/msgpackbuilder.h index fa6244d1e..67b4dd1df 100644 --- a/cpp_src/core/cjson/msgpackbuilder.h +++ b/cpp_src/core/cjson/msgpackbuilder.h @@ -5,7 +5,7 @@ #include "core/cjson/tagsmatcher.h" #include "core/keyvalue/p_string.h" #include "core/payload/payloadiface.h" -#include "estl/span.h" +#include #include "vendor/msgpack/msgpack.h" namespace gason { @@ -36,7 +36,7 @@ class MsgPackBuilder { MsgPackBuilder Raw(std::nullptr_t, std::string_view arg) { return Raw(std::string_view{}, arg); } template - void Array(N tagName, span data, int /*offset*/ = 0) { + void Array(N tagName, std::span data, int /*offset*/ = 0) { checkIfCorrectArray(tagName); skipTag(); packKeyName(tagName); @@ -46,7 +46,7 @@ class MsgPackBuilder { } } template - void Array(N tagName, span data, int /*offset*/ = 0) { + void Array(N tagName, std::span data, int /*offset*/ = 0) { checkIfCorrectArray(tagName); skipTag(); packKeyName(tagName); @@ -57,7 +57,7 @@ class MsgPackBuilder { } template - void Array(T tagName, span data, int /*offset*/ = 0) { + void Array(T tagName, std::span data, int /*offset*/ = 0) { checkIfCorrectArray(tagName); skipTag(); packKeyName(tagName); @@ -139,15 +139,17 @@ class MsgPackBuilder { packKeyName(tagName); kv.Type().EvaluateOneOf( [&](KeyValueType::Int) { packValue(int(kv)); }, [&](KeyValueType::Int64) { packValue(int64_t(kv)); }, - [&](KeyValueType::Double) { packValue(double(kv)); }, [&](KeyValueType::String) { packValue(std::string_view(kv)); }, - [&](KeyValueType::Null) { packNil(); }, [&](KeyValueType::Bool) { packValue(bool(kv)); }, + [&](KeyValueType::Double) { packValue(double(kv)); }, [&](KeyValueType::Float) { packValue(float(kv)); }, + [&](KeyValueType::String) { packValue(std::string_view(kv)); }, [&](KeyValueType::Null) { packNil(); }, + [&](KeyValueType::Bool) { packValue(bool(kv)); }, [&](KeyValueType::Tuple) { auto arrNode = Array(tagName); for (auto& val : kv.getCompositeValues()) { arrNode.Put(0, val, offset); } }, - [&](KeyValueType::Uuid) { packValue(Uuid{kv}); }, [](OneOf) noexcept {}); + [&](KeyValueType::Uuid) { packValue(Uuid{kv}); }, + [](OneOf) noexcept { assertrx_throw(false); }); if (isArray()) { skipTag(); } @@ -168,6 +170,7 @@ class MsgPackBuilder { void packValue(int arg) { msgpack_pack_int(&packer_, arg); } void packValue(int64_t arg) { msgpack_pack_int64(&packer_, arg); } void packValue(double arg) { msgpack_pack_double(&packer_, arg); } + void packValue(float arg) { msgpack_pack_float(&packer_, arg); } void packValue(std::string_view arg) { msgpack_pack_str(&packer_, arg.size()); diff --git a/cpp_src/core/cjson/msgpackdecoder.cc 
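
The new packValue(float) above maps straight onto msgpack_pack_float, so TAG_FLOAT values go out as msgpack FLOAT32 rather than being widened to double. Minimal standalone usage of the msgpack-c packer, assuming the standard C API headers:

    #include <msgpack.h>

    void packOneFloat(float value) {
        msgpack_sbuffer sbuf;
        msgpack_sbuffer_init(&sbuf);
        msgpack_packer pk;
        msgpack_packer_init(&pk, &sbuf, msgpack_sbuffer_write);
        msgpack_pack_float(&pk, value);  // FLOAT32 on the wire; cf. msgpack_pack_double
        msgpack_sbuffer_destroy(&sbuf);
    }
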
b/cpp_src/core/cjson/msgpackdecoder.cc index 0c7cf525c..47d04594f 100644 --- a/cpp_src/core/cjson/msgpackdecoder.cc +++ b/cpp_src/core/cjson/msgpackdecoder.cc @@ -1,8 +1,10 @@ #include "msgpackdecoder.h" +#include "core/cjson/cjsonbuilder.h" #include "core/cjson/cjsontools.h" #include "core/cjson/objtype.h" #include "core/cjson/tagsmatcher.h" +#include "core/keyvalue/float_vectors_holder.h" #include "tools/flagguard.h" namespace reindexer { @@ -17,7 +19,7 @@ void MsgPackDecoder::setValue(Payload& pl, CJsonBuilder& builder, const T& value validateArrayFieldRestrictions(f, 1, "msgpack"); } Variant val(value); - builder.Ref(tagName, val, field); + builder.Ref(tagName, val.Type(), field); pl.Set(field, convertValueForPayload(pl, field, std::move(val), "msgpack")); objectScalarIndexes_.set(field); } else { @@ -48,7 +50,9 @@ int MsgPackDecoder::decodeKeyToTag(const msgpack_object_kv& obj) { throw Error(errParams, "Unsupported MsgPack map key type: %s(%d)", ToString(obj.key.type), int(obj.key.type)); } -void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_object& obj, int tagName) { +void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_object& obj, int tagName, + FloatVectorsHolderVector& floatVectorsHolder) { + using namespace std::string_view_literals; if (tagName) { tagsPath_.emplace_back(tagName); } @@ -73,7 +77,7 @@ void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_ob setValue(pl, builder, p_string(reinterpret_cast(&obj.via.str)), tagName); break; case MSGPACK_OBJECT_ARRAY: { - int count = 0; + size_t count = 0; CounterGuardIR32 g(arrayLevel_); const msgpack_object* begin = obj.via.array.ptr; const msgpack_object* end = begin + obj.via.array.size; @@ -89,46 +93,81 @@ void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_ob } int field = tm_.tags2field(tagsPath_.data(), tagsPath_.size()); if (field > 0) { - auto& f = pl.Type().Field(field); - if rx_unlikely (!f.IsArray()) { - throw Error(errLogic, "Error parsing msgpack field '%s' - got array, expected scalar %s", f.Name(), f.Type().Name()); - } - validateArrayFieldRestrictions(f, count, "msgpack"); - int pos = pl.ResizeArray(field, count, true); - for (const msgpack_object* p = begin; p != end; ++p) { - pl.Set(field, pos++, - convertValueForPayload( - pl, field, - [&] { - switch (p->type) { - case MSGPACK_OBJECT_BOOLEAN: - return Variant{p->via.boolean}; - case MSGPACK_OBJECT_POSITIVE_INTEGER: - return Variant{int64_t(p->via.u64)}; - case MSGPACK_OBJECT_NEGATIVE_INTEGER: - return Variant{p->via.i64}; - case MSGPACK_OBJECT_FLOAT32: - case MSGPACK_OBJECT_FLOAT64: - return Variant{p->via.f64}; - case MSGPACK_OBJECT_STR: - return Variant{p_string(reinterpret_cast(&p->via.str)), Variant::hold_t{}}; - case MSGPACK_OBJECT_NIL: - case MSGPACK_OBJECT_ARRAY: - case MSGPACK_OBJECT_MAP: - case MSGPACK_OBJECT_BIN: - case MSGPACK_OBJECT_EXT: - default: - throw Error(errParams, "Unsupported MsgPack array field type: %s(%d)", ToString(p->type), - int(p->type)); - } - }(), - "msgpack")); + const auto& f = pl.Type().Field(field); + if (f.IsFloatVector()) { + ConstFloatVectorView vectView; + if (count != 0) { + if (count != size_t(f.FloatVectorDimension())) { + throwUnexpectedArraySizeForFloatVectorError("msgpack"sv, f, count); + } + auto vect = FloatVector::CreateNotInitialized(f.FloatVectorDimension()); + size_t pos = 0; + for (const msgpack_object* p = begin; p != end; ++p, ++pos) { + assertrx(pos < size_t(f.FloatVectorDimension())); + switch (p->type) { + 
case MSGPACK_OBJECT_FLOAT32: + case MSGPACK_OBJECT_FLOAT64: + vect.RawData()[pos] = p->via.f64; + break; + case MSGPACK_OBJECT_BOOLEAN: + case MSGPACK_OBJECT_POSITIVE_INTEGER: + case MSGPACK_OBJECT_NEGATIVE_INTEGER: + case MSGPACK_OBJECT_STR: + case MSGPACK_OBJECT_NIL: + case MSGPACK_OBJECT_ARRAY: + case MSGPACK_OBJECT_MAP: + case MSGPACK_OBJECT_BIN: + case MSGPACK_OBJECT_EXT: + default: + throwUnexpectedArrayTypeForFloatVectorError("msgpack"sv, f); + } + } + floatVectorsHolder.Add(std::move(vect)); + vectView = floatVectorsHolder.Back(); + } + pl.Set(field, Variant{vectView}); + } else { + if rx_unlikely (!f.IsArray()) { + throwUnexpectedArrayError("msgpack"sv, f); + } + validateArrayFieldRestrictions(f, count, "msgpack"); + int pos = pl.ResizeArray(field, count, true); + for (const msgpack_object* p = begin; p != end; ++p) { + pl.Set(field, pos++, + convertValueForPayload( + pl, field, + [&] { + switch (p->type) { + case MSGPACK_OBJECT_BOOLEAN: + return Variant{p->via.boolean}; + case MSGPACK_OBJECT_POSITIVE_INTEGER: + return Variant{int64_t(p->via.u64)}; + case MSGPACK_OBJECT_NEGATIVE_INTEGER: + return Variant{p->via.i64}; + case MSGPACK_OBJECT_FLOAT32: + case MSGPACK_OBJECT_FLOAT64: + return Variant{p->via.f64}; + case MSGPACK_OBJECT_STR: + return Variant{p_string(reinterpret_cast(&p->via.str)), + Variant::HoldT{}}; + case MSGPACK_OBJECT_NIL: + case MSGPACK_OBJECT_ARRAY: + case MSGPACK_OBJECT_MAP: + case MSGPACK_OBJECT_BIN: + case MSGPACK_OBJECT_EXT: + default: + throw Error(errParams, "Unsupported MsgPack array field type: %s(%d)", ToString(p->type), + int(p->type)); + } + }(), + "msgpack")); + } } builder.ArrayRef(tagName, field, count); } else { auto array = builder.Array(tagName, type); for (const msgpack_object* p = begin; p != end; ++p) { - decode(pl, array, *p, 0); + decode(pl, array, *p, 0, floatVectorsHolder); } } break; @@ -141,7 +180,7 @@ void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_ob // MsgPack can have non-string type keys: https://github.com/msgpack/msgpack/issues/217 assertrx(p); int tag = decodeKeyToTag(*p); - decode(pl, object, p->val, tag); + decode(pl, object, p->val, tag, floatVectorsHolder); } break; } @@ -155,7 +194,8 @@ void MsgPackDecoder::decode(Payload& pl, CJsonBuilder& builder, const msgpack_ob } } -Error MsgPackDecoder::Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, size_t& offset) { +Error MsgPackDecoder::Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, size_t& offset, + FloatVectorsHolderVector& floatVectorsHolder) { try { objectScalarIndexes_.reset(); tagsPath_.clear(); @@ -171,7 +211,7 @@ Error MsgPackDecoder::Decode(std::string_view buf, Payload& pl, WrSerializer& wr } CJsonBuilder cjsonBuilder(wrser, ObjType::TypePlain, &tm_, 0); - decode(pl, cjsonBuilder, *(data.p), 0); + decode(pl, cjsonBuilder, *(data.p), 0, floatVectorsHolder); } catch (const Error& err) { return err; } catch (const std::exception& ex) { diff --git a/cpp_src/core/cjson/msgpackdecoder.h b/cpp_src/core/cjson/msgpackdecoder.h index e89f2ff8e..0308e253c 100644 --- a/cpp_src/core/cjson/msgpackdecoder.h +++ b/cpp_src/core/cjson/msgpackdecoder.h @@ -1,6 +1,5 @@ #pragma once -#include "core/cjson/cjsonbuilder.h" #include "core/payload/payloadiface.h" #include "tools/errors.h" #include "vendor/msgpack/msgpackparser.h" @@ -11,14 +10,16 @@ namespace reindexer { class TagsMatcher; class WrSerializer; +class CJsonBuilder; +class FloatVectorsHolderVector; class MsgPackDecoder { public: explicit MsgPackDecoder(TagsMatcher& 
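
In the fill loop above, both MSGPACK_OBJECT_FLOAT32 and MSGPACK_OBJECT_FLOAT64 elements are read through via.f64, which is why one shared case can populate the vector. The element conversion in isolation, as a standalone sketch against the msgpack-c object API:

    #include <msgpack.h>
    #include <stdexcept>

    float asFloat(const msgpack_object& o) {
        switch (o.type) {
            case MSGPACK_OBJECT_FLOAT32:
            case MSGPACK_OBJECT_FLOAT64:
                return static_cast<float>(o.via.f64);  // both sizes exposed as f64
            default:
                throw std::runtime_error("expected a float element");
        }
    }
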
tagsMatcher) noexcept : tm_(tagsMatcher) {} - Error Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, size_t& offset); + Error Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, size_t& offset, FloatVectorsHolderVector&); private: - void decode(Payload& pl, CJsonBuilder& builder, const msgpack_object& obj, int tagName); + void decode(Payload& pl, CJsonBuilder& builder, const msgpack_object& obj, int tagName, FloatVectorsHolderVector&); int decodeKeyToTag(const msgpack_object_kv& obj); diff --git a/cpp_src/core/cjson/protobufbuilder.cc b/cpp_src/core/cjson/protobufbuilder.cc index c9237a642..ec17b1eaa 100644 --- a/cpp_src/core/cjson/protobufbuilder.cc +++ b/cpp_src/core/cjson/protobufbuilder.cc @@ -54,6 +54,9 @@ void ProtobufBuilder::packItem(int fieldIdx, TagType tagType, Serializer& rdser, case TAG_UUID: array.put(fieldIdx, rdser.GetUuid()); break; + case TAG_FLOAT: + array.put(fieldIdx, rdser.GetFloat()); + break; case TAG_NULL: array.Null(fieldIdx); break; @@ -96,8 +99,12 @@ void ProtobufBuilder::put(int fieldIdx, int val) { put(fieldIdx, double(val)); done = true; }, + [&](KeyValueType::Float) { + put(fieldIdx, float(val)); + done = true; + }, [&](OneOf) { + KeyValueType::Undefined, KeyValueType::Uuid, KeyValueType::FloatVector>) { throw Error(errParams, "Expected type '%s' for field '%s'", res.first.Name(), tm_->tag2name(fieldIdx)); }); } @@ -121,8 +128,12 @@ void ProtobufBuilder::put(int fieldIdx, int64_t val) { put(fieldIdx, double(val)); done = true; }, + [&](KeyValueType::Float) { + put(fieldIdx, float(val)); + done = true; + }, [&](OneOf) { + KeyValueType::Undefined, KeyValueType::Uuid, KeyValueType::FloatVector>) { throw Error(errParams, "Expected type '%s' for field '%s'", res.first.Name(), tm_->tag2name(fieldIdx)); }); } @@ -142,12 +153,16 @@ void ProtobufBuilder::put(int fieldIdx, double val) { put(fieldIdx, int(val)); done = true; }, + [&](KeyValueType::Float) { + put(fieldIdx, float(val)); + done = true; + }, [&](KeyValueType::Int64) { put(fieldIdx, int64_t(val)); done = true; }, [&](OneOf) { + KeyValueType::Undefined, KeyValueType::Uuid, KeyValueType::FloatVector>) { throw Error(errParams, "Expected type '%s' for field '%s'", res.first.Name(), tm_->tag2name(fieldIdx)); }); } @@ -159,6 +174,36 @@ void ProtobufBuilder::put(int fieldIdx, double val) { } } +void ProtobufBuilder::put(int fieldIdx, float val) { + bool done = false; + if (const auto res = getExpectedFieldType(); res.second) { + res.first.EvaluateOneOf( + [&](KeyValueType::Double) noexcept { + put(fieldIdx, double(val)); + done = true; + }, + [&](OneOf) { + put(fieldIdx, int(val)); + done = true; + }, + [&](KeyValueType::Float) {}, + [&](KeyValueType::Int64) { + put(fieldIdx, int64_t(val)); + done = true; + }, + [&](OneOf) { + throw Error(errParams, "Expected type '%s' for field '%s'", res.first.Name(), tm_->tag2name(fieldIdx)); + }); + } + if (!done) { + if (type_ != ObjType::TypeArray) { + putFieldHeader(fieldIdx, PBUF_TYPE_FLOAT32); + } + ser_->PutFloat(val); + } +} + void ProtobufBuilder::put(int fieldIdx, std::string_view val) { if (const auto res = getExpectedFieldType(); res.second) { if (!res.first.Is()) { @@ -184,18 +229,18 @@ void ProtobufBuilder::put(int fieldIdx, Uuid val) { } void ProtobufBuilder::put(int fieldIdx, const Variant& val) { - val.Type().EvaluateOneOf([&](KeyValueType::Int64) { put(fieldIdx, int64_t(val)); }, [&](KeyValueType::Int) { put(fieldIdx, int(val)); }, - [&](KeyValueType::Double) { put(fieldIdx, double(val)); }, - [&](KeyValueType::String) { put(fieldIdx, 
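
msgpackdecoder.h above drops the cjsonbuilder include and forward-declares CJsonBuilder and FloatVectorsHolderVector instead: the header only mentions them by reference, so the full definitions can live in the .cc file. The general pattern:

    // decoder.h (sketch): forward declarations are enough when the interface
    // only takes references or pointers; heavy includes move to decoder.cc.
    #pragma once
    #include <string_view>

    class CJsonBuilder;
    class FloatVectorsHolderVector;

    class DecoderSketch {
    public:
        void Decode(std::string_view buf, CJsonBuilder&, FloatVectorsHolderVector&);
    };
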
std::string_view(val)); }, - [&](KeyValueType::Bool) { put(fieldIdx, bool(val)); }, - [&](KeyValueType::Tuple) { - auto arrNode = ArrayPacked(fieldIdx); - for (auto& itVal : val.getCompositeValues()) { - arrNode.Put(fieldIdx, itVal, 0); - } - }, - [&](KeyValueType::Uuid) { put(fieldIdx, Uuid{val}); }, - [&](OneOf) noexcept {}); + val.Type().EvaluateOneOf( + [&](KeyValueType::Int64) { put(fieldIdx, int64_t(val)); }, [&](KeyValueType::Int) { put(fieldIdx, int(val)); }, + [&](KeyValueType::Double) { put(fieldIdx, double(val)); }, [&](KeyValueType::Float) { put(fieldIdx, float(val)); }, + [&](KeyValueType::String) { put(fieldIdx, std::string_view(val)); }, [&](KeyValueType::Bool) { put(fieldIdx, bool(val)); }, + [&](KeyValueType::Tuple) { + auto arrNode = ArrayPacked(fieldIdx); + for (auto& itVal : val.getCompositeValues()) { + arrNode.Put(fieldIdx, itVal, 0); + } + }, + [&](KeyValueType::Uuid) { put(fieldIdx, Uuid{val}); }, + [&](OneOf) noexcept {}); } ProtobufBuilder ProtobufBuilder::Object(int fieldIdx, int) { diff --git a/cpp_src/core/cjson/protobufbuilder.h b/cpp_src/core/cjson/protobufbuilder.h index 513d7bf1c..21c5728fa 100644 --- a/cpp_src/core/cjson/protobufbuilder.h +++ b/cpp_src/core/cjson/protobufbuilder.h @@ -5,7 +5,7 @@ #include "core/cjson/tagslengths.h" #include "core/cjson/tagsmatcher.h" #include "core/keyvalue/p_string.h" -#include "estl/span.h" +#include #include "tools/serializer.h" namespace reindexer { @@ -70,7 +70,7 @@ class ProtobufBuilder { template ::value || std::is_floating_point::value || std::is_same::value>::type* = nullptr> - void Array(int fieldIdx, span data, int /*offset*/ = 0) { + void Array(int fieldIdx, std::span data, int /*offset*/ = 0) { auto array = ArrayPacked(fieldIdx); for (const T& item : data) { array.put(0, item); @@ -78,13 +78,13 @@ class ProtobufBuilder { } template ::value>::type* = nullptr> - void Array(int fieldIdx, span data, int /*offset*/ = 0) { + void Array(int fieldIdx, std::span data, int /*offset*/ = 0) { auto array = ArrayNotPacked(fieldIdx); for (const T& item : data) { array.put(fieldIdx, std::string_view(item)); } } - void Array(int fieldIdx, span data, int /*offset*/ = 0) { + void Array(int fieldIdx, std::span data, int /*offset*/ = 0) { auto array = ArrayNotPacked(fieldIdx); for (Uuid item : data) { array.put(fieldIdx, item); @@ -131,6 +131,7 @@ class ProtobufBuilder { void put(int fieldIdx, int val); void put(int fieldIdx, int64_t val); void put(int fieldIdx, double val); + void put(int fieldIdx, float val); void put(int fieldIdx, std::string_view val); void put(int fieldIdx, const Variant& val); void put(int fieldIdx, Uuid val); diff --git a/cpp_src/core/cjson/protobufdecoder.cc b/cpp_src/core/cjson/protobufdecoder.cc index 45d43d53b..039b9ac1c 100644 --- a/cpp_src/core/cjson/protobufdecoder.cc +++ b/cpp_src/core/cjson/protobufdecoder.cc @@ -1,5 +1,6 @@ #include "protobufdecoder.h" #include "core/cjson/cjsontools.h" +#include "core/keyvalue/float_vectors_holder.h" #include "core/schema.h" #include "estl/protobufparser.h" @@ -53,7 +54,7 @@ void ProtobufDecoder::setValue(Payload& pl, CJsonBuilder& builder, ProtobufValue arraysStorage_.UpdateArraySize(item.tagName, field); } else { validateArrayFieldRestrictions(f, 1, "protobuf"); - builder.Ref(item.tagName, value, field); + builder.Ref(item.tagName, value.Type(), field); } pl.Set(field, convertValueForPayload(pl, field, std::move(value), "protobuf"), true); objectScalarIndexes_.set(field); @@ -67,27 +68,55 @@ void ProtobufDecoder::setValue(Payload& pl, CJsonBuilder& builder, 
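
The new put(int fieldIdx, float val) above writes a PBUF_TYPE_FLOAT32 header and then the raw float. In protobuf wire terms a float field is wire type 5 (fixed 32-bit) with a varint-encoded key of (field_number << 3) | wire_type; a hand-rolled sketch of that encoding, independent of the builder classes here and assuming a little-endian host:

    #include <cstdint>
    #include <cstring>
    #include <string>

    void putVarint(std::string& out, uint64_t v) {
        while (v >= 0x80) {
            out.push_back(char((v & 0x7f) | 0x80));
            v >>= 7;
        }
        out.push_back(char(v));
    }

    void putFloatField(std::string& out, uint32_t fieldNumber, float value) {
        putVarint(out, (uint64_t(fieldNumber) << 3) | 5);  // wire type 5: 32-bit
        char raw[sizeof(float)];
        std::memcpy(raw, &value, sizeof(float));           // protobuf floats are little-endian
        out.append(raw, sizeof(float));
    }
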
ProtobufValue } } -Error ProtobufDecoder::decodeArray(Payload& pl, CJsonBuilder& builder, const ProtobufValue& item) { +Error ProtobufDecoder::decodeArray(Payload& pl, CJsonBuilder& builder, const ProtobufValue& item, + FloatVectorsHolderVector& floatVectorsHolder) { + using namespace std::string_view_literals; ProtobufObject object(item.As(), *schema_, tagsPath_, tm_); ProtobufParser parser(object); const bool packed = item.IsOfPrimitiveType(); const int field = tm_.tags2field(tagsPath_.data(), tagsPath_.size()); if (field > 0) { const auto& f = pl.Type().Field(field); - if rx_unlikely (!f.IsArray()) { - throw Error(errLogic, "Error parsing protobuf field '%s' - got array, expected scalar %s", f.Name(), f.Type().Name()); - } - if (packed) { - int count = 0; - while (!parser.IsEof()) { - pl.Set(field, convertValueForPayload(pl, field, parser.ReadArrayItem(item.itemType), "protobuf"), true); - ++count; + if (f.IsFloatVector()) { + if (!item.itemType.IsNumeric() || item.itemType.Is()) { + throwUnexpectedArrayTypeForFloatVectorError("protobuf"sv, f); + } + ConstFloatVectorView vectView; + size_t count = 0; + if (!parser.IsEof()) { + auto vect = FloatVector::CreateNotInitialized(f.FloatVectorDimension()); + while (!parser.IsEof()) { + if (count >= size_t(f.FloatVectorDimension())) { + throwUnexpectedArraySizeForFloatVectorError("protobuf"sv, f, count); + } + const Variant value = parser.ReadArrayItem(item.itemType); + vect.RawData()[count] = value.As(); + ++count; + } + if (count != size_t(f.FloatVectorDimension())) { + throwUnexpectedArraySizeForFloatVectorError("protobuf"sv, f, count); + } + floatVectorsHolder.Add(std::move(vect)); + vectView = floatVectorsHolder.Back(); } + pl.Set(field, Variant{vectView}); builder.ArrayRef(item.tagName, field, count); } else { - setValue(pl, builder, item); + if rx_unlikely (!f.IsArray()) { + throwUnexpectedArrayError("protobuf"sv, f); + } + if (packed) { + int count = 0; + while (!parser.IsEof()) { + pl.Set(field, convertValueForPayload(pl, field, parser.ReadArrayItem(item.itemType), "protobuf"), true); + ++count; + } + builder.ArrayRef(item.tagName, field, count); + } else { + setValue(pl, builder, item); + } + validateArrayFieldRestrictions(f, reinterpret_cast(pl.Field(field).p_)->len, "protobuf"); } - validateArrayFieldRestrictions(f, reinterpret_cast(pl.Field(field).p_)->len, "protobuf"); } else { CJsonBuilder& array = arraysStorage_.GetArray(item.tagName); if (packed) { @@ -99,7 +128,7 @@ Error ProtobufDecoder::decodeArray(Payload& pl, CJsonBuilder& builder, const Pro Error status; CJsonProtobufObjectBuilder obj(array, 0, arraysStorage_); while (status.ok() && !parser.IsEof()) { - status = decode(pl, obj, parser.ReadValue()); + status = decode(pl, obj, parser.ReadValue(), floatVectorsHolder); } } else { setValue(pl, array, item); @@ -109,16 +138,16 @@ Error ProtobufDecoder::decodeArray(Payload& pl, CJsonBuilder& builder, const Pro return {}; } -Error ProtobufDecoder::decode(Payload& pl, CJsonBuilder& builder, const ProtobufValue& item) { +Error ProtobufDecoder::decode(Payload& pl, CJsonBuilder& builder, const ProtobufValue& item, FloatVectorsHolderVector& floatVectorsHolder) { TagsPathScope tagScope(tagsPath_, item.tagName); return item.value.Type().EvaluateOneOf( - [&](OneOf) { + [&](OneOf) { setValue(pl, builder, item); return Error{}; }, [&](KeyValueType::String) { if (item.isArray) { - return decodeArray(pl, builder, item); + return decodeArray(pl, builder, item, floatVectorsHolder); } else { return item.itemType.EvaluateOneOf( 
[&](KeyValueType::String) { @@ -128,36 +157,39 @@ Error ProtobufDecoder::decode(Payload& pl, CJsonBuilder& builder, const Protobuf [&](KeyValueType::Composite) { CJsonProtobufObjectBuilder objBuilder(builder, item.tagName, arraysStorage_); ProtobufObject object(item.As(), *schema_, tagsPath_, tm_); - return decodeObject(pl, objBuilder, object); + return decodeObject(pl, objBuilder, object, floatVectorsHolder); }, - [&](OneOf) { + [&](OneOf) { return Error(errParseProtobuf, "Error parsing length-encoded type: [%s] for field [%s]", item.itemType.Name(), tm_.tag2name(item.tagName)); }); } }, - [&](OneOf) { + [&](OneOf) { return Error(errParseProtobuf, "Unknown field type [%s] while parsing Protobuf", item.value.Type().Name()); }); } -Error ProtobufDecoder::decodeObject(Payload& pl, CJsonBuilder& builder, ProtobufObject& object) { +Error ProtobufDecoder::decodeObject(Payload& pl, CJsonBuilder& builder, ProtobufObject& object, + FloatVectorsHolderVector& floatVectorsHolder) { Error status; ProtobufParser parser(object); while (status.ok() && !parser.IsEof()) { - status = decode(pl, builder, parser.ReadValue()); + status = decode(pl, builder, parser.ReadValue(), floatVectorsHolder); } return status; } -Error ProtobufDecoder::Decode(std::string_view buf, Payload& pl, WrSerializer& wrser) { +Error ProtobufDecoder::Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, FloatVectorsHolderVector& floatVectorsHolder) { try { tagsPath_.clear(); objectScalarIndexes_.reset(); CJsonProtobufObjectBuilder cjsonBuilder(arraysStorage_, wrser, &tm_, 0); ProtobufObject object(buf, *schema_, tagsPath_, tm_); - return decodeObject(pl, cjsonBuilder, object); + return decodeObject(pl, cjsonBuilder, object, floatVectorsHolder); } catch (Error& err) { return err; } diff --git a/cpp_src/core/cjson/protobufdecoder.h b/cpp_src/core/cjson/protobufdecoder.h index 7cba753d7..b6121753a 100644 --- a/cpp_src/core/cjson/protobufdecoder.h +++ b/cpp_src/core/cjson/protobufdecoder.h @@ -9,6 +9,7 @@ namespace reindexer { class Schema; struct ProtobufValue; struct ProtobufObject; +class FloatVectorsHolderVector; class ArraysStorage { public: @@ -75,13 +76,13 @@ class ProtobufDecoder { ProtobufDecoder& operator=(const ProtobufDecoder&) = delete; ProtobufDecoder& operator=(ProtobufDecoder&&) = delete; - Error Decode(std::string_view buf, Payload& pl, WrSerializer& wrser); + Error Decode(std::string_view buf, Payload& pl, WrSerializer& wrser, FloatVectorsHolderVector&); private: void setValue(Payload& pl, CJsonBuilder& builder, ProtobufValue item); - Error decode(Payload& pl, CJsonBuilder& builder, const ProtobufValue& val); - Error decodeObject(Payload& pl, CJsonBuilder& builder, ProtobufObject& object); - Error decodeArray(Payload& pl, CJsonBuilder& builder, const ProtobufValue& val); + Error decode(Payload& pl, CJsonBuilder& builder, const ProtobufValue& val, FloatVectorsHolderVector&); + Error decodeObject(Payload& pl, CJsonBuilder& builder, ProtobufObject& object, FloatVectorsHolderVector&); + Error decodeArray(Payload& pl, CJsonBuilder& builder, const ProtobufValue& val, FloatVectorsHolderVector&); TagsMatcher& tm_; std::shared_ptr schema_; diff --git a/cpp_src/core/cjson/protobufschemabuilder.cc b/cpp_src/core/cjson/protobufschemabuilder.cc index c6ca06c31..1ccdce76c 100644 --- a/cpp_src/core/cjson/protobufschemabuilder.cc +++ b/cpp_src/core/cjson/protobufschemabuilder.cc @@ -90,13 +90,13 @@ void ProtobufSchemaBuilder::Field(std::string_view name, int tagName, const Fiel } writeField(name, typeName, tagName); 
type.EvaluateOneOf( - [&](OneOf) { + [&](OneOf) { if (ser_) { ser_->Write(" [packed=true]"); } }, [](OneOf) noexcept {}); + KeyValueType::Uuid, KeyValueType::FloatVector>) noexcept {}); } else { writeField(name, typeName, tagName); } diff --git a/cpp_src/core/cjson/recoder.h b/cpp_src/core/cjson/recoder.h new file mode 100644 index 000000000..967f804bd --- /dev/null +++ b/cpp_src/core/cjson/recoder.h @@ -0,0 +1,22 @@ +#pragma once + +#include "core/payload/payloadiface.h" +#include "tagspath.h" + +namespace reindexer { + +class Serializer; +class WrSerializer; + +class Recoder { +public: + [[nodiscard]] virtual TagType Type(TagType oldTagType) = 0; + virtual void Recode(Serializer&, WrSerializer&) const = 0; + virtual void Recode(Serializer&, Payload&, int tagName, WrSerializer&) = 0; + [[nodiscard]] virtual bool Match(int field) const noexcept = 0; + [[nodiscard]] virtual bool Match(const TagsPath&) const = 0; + virtual void Prepare(IdType rowId) noexcept = 0; + virtual ~Recoder() = default; +}; + +} // namespace reindexer diff --git a/cpp_src/core/cjson/tagsmatcherimpl.h b/cpp_src/core/cjson/tagsmatcherimpl.h index 468cdc724..9342f2c6b 100644 --- a/cpp_src/core/cjson/tagsmatcherimpl.h +++ b/cpp_src/core/cjson/tagsmatcherimpl.h @@ -209,7 +209,7 @@ class TagsMatcherImpl { void deserialize(Serializer& ser) { clear(); - size_t cnt = ser.GetVarUint(); + size_t cnt = ser.GetVarUInt(); validateTagSize(cnt); tags2names_.resize(cnt); for (size_t tag = 0; tag < tags2names_.size(); ++tag) { @@ -217,7 +217,6 @@ class TagsMatcherImpl { names2tags_.emplace(name, tag); tags2names_[tag] = name; } - // assert(ser.Eof()); } void deserialize(Serializer& ser, int version, int stateToken) { deserialize(ser); diff --git a/cpp_src/core/cjson/uuid_recoders.h b/cpp_src/core/cjson/uuid_recoders.h index 9bb51c701..7b3c91fc6 100644 --- a/cpp_src/core/cjson/uuid_recoders.h +++ b/cpp_src/core/cjson/uuid_recoders.h @@ -1,14 +1,14 @@ #pragma once -#include "cjsondecoder.h" +#include "recoder.h" namespace reindexer { template -class RecoderUuidToString : public Recoder { +class RecoderUuidToString final : public Recoder { public: explicit RecoderUuidToString(TagsPath tp) noexcept : tagsPath_{std::move(tp)} {} - [[nodiscard]] TagType Type([[maybe_unused]] TagType oldTagType) noexcept override final { + [[nodiscard]] TagType Type([[maybe_unused]] TagType oldTagType) noexcept override { if constexpr (Array) { assertrx(oldTagType == TAG_ARRAY); return TAG_ARRAY; @@ -17,10 +17,11 @@ class RecoderUuidToString : public Recoder { return TAG_STRING; } } - void Recode(Serializer&, WrSerializer&) const override final; - void Recode(Serializer&, Payload&, int, WrSerializer&) override final { assertrx(false); } - [[nodiscard]] bool Match(int) const noexcept final { return false; } - [[nodiscard]] bool Match(const TagsPath& tp) const noexcept final { return tagsPath_ == tp; } + void Recode(Serializer&, WrSerializer&) const override; + void Recode(Serializer&, Payload&, int, WrSerializer&) override { assertrx(false); } + [[nodiscard]] bool Match(int) const noexcept override { return false; } + [[nodiscard]] bool Match(const TagsPath& tp) const noexcept override { return tagsPath_ == tp; } + void Prepare(IdType) noexcept override {} private: TagsPath tagsPath_; @@ -42,20 +43,20 @@ inline void RecoderUuidToString::Recode(Serializer& rdser, WrSerializer& w } } -class RecoderStringToUuidArray : public Recoder { +class RecoderStringToUuidArray final : public Recoder { public: explicit RecoderStringToUuidArray(int f) noexcept : field_{f} {} 
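
recoder.h above pulls the Recoder interface into its own header and adds a Prepare(IdType rowId) hook, presumably invoked per row before recoding; the UUID recoders below become final and implement it as a no-op. A trimmed-down analogue of the shape, with simplified stand-in types rather than the real reindexer ones:

    #include <cstdint>

    class RecoderBase {
    public:
        virtual void Prepare(int64_t rowId) noexcept = 0;           // per-row setup hook
        [[nodiscard]] virtual bool Match(int field) const noexcept = 0;
        virtual ~RecoderBase() = default;
    };

    // 'final' lets the compiler devirtualize calls through the concrete type.
    class NoopRecoder final : public RecoderBase {
    public:
        void Prepare(int64_t) noexcept override {}
        [[nodiscard]] bool Match(int) const noexcept override { return false; }
    };
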
- [[nodiscard]] TagType Type(TagType oldTagType) override final { + [[nodiscard]] TagType Type(TagType oldTagType) override { fromNotArrayField_ = oldTagType != TAG_ARRAY; if (fromNotArrayField_ && oldTagType != TAG_STRING) { throw Error(errLogic, "Cannot convert not string field to UUID"); } return TAG_ARRAY; } - [[nodiscard]] bool Match(int f) const noexcept final { return f == field_; } - [[nodiscard]] bool Match(const TagsPath&) const noexcept final { return false; } - void Recode(Serializer&, WrSerializer&) const override final { assertrx(false); } - void Recode(Serializer& rdser, Payload& pl, int tagName, WrSerializer& wrser) override final { + [[nodiscard]] bool Match(int f) const noexcept override { return f == field_; } + [[nodiscard]] bool Match(const TagsPath&) const noexcept override { return false; } + void Recode(Serializer&, WrSerializer&) const override { assertrx(false); } + void Recode(Serializer& rdser, Payload& pl, int tagName, WrSerializer& wrser) override { if (fromNotArrayField_) { pl.Set(field_, Variant{rdser.GetStrUuid()}, true); wrser.PutCTag(ctag{TAG_ARRAY, tagName, field_}); @@ -76,6 +77,7 @@ class RecoderStringToUuidArray : public Recoder { wrser.PutVarUint(count); } } + void Prepare(IdType) noexcept override {} private: const int field_{std::numeric_limits::max()}; @@ -83,10 +85,10 @@ class RecoderStringToUuidArray : public Recoder { bool fromNotArrayField_{false}; }; -class RecoderStringToUuid : public Recoder { +class RecoderStringToUuid final : public Recoder { public: explicit RecoderStringToUuid(int f) noexcept : field_{f} {} - [[nodiscard]] TagType Type(TagType oldTagType) override final { + [[nodiscard]] TagType Type(TagType oldTagType) override { if (oldTagType == TAG_ARRAY) { throw Error(errLogic, "Cannot convert array field to not array UUID"); } else if (oldTagType != TAG_STRING) { @@ -94,13 +96,14 @@ class RecoderStringToUuid : public Recoder { } return TAG_UUID; } - [[nodiscard]] bool Match(int f) const noexcept final { return f == field_; } - [[nodiscard]] bool Match(const TagsPath&) const noexcept final { return false; } - void Recode(Serializer&, WrSerializer&) const override final { assertrx(false); } - void Recode(Serializer& rdser, Payload& pl, int tagName, WrSerializer& wrser) override final { + [[nodiscard]] bool Match(int f) const noexcept override { return f == field_; } + [[nodiscard]] bool Match(const TagsPath&) const noexcept override { return false; } + void Recode(Serializer&, WrSerializer&) const override { assertrx(false); } + void Recode(Serializer& rdser, Payload& pl, int tagName, WrSerializer& wrser) override { pl.Set(field_, Variant{rdser.GetStrUuid()}, true); wrser.PutCTag(ctag{TAG_UUID, tagName, field_}); } + void Prepare(IdType) noexcept override {} private: const int field_{std::numeric_limits::max()}; diff --git a/cpp_src/core/clusterproxy.cc b/cpp_src/core/clusterproxy.cc index 23632fcc8..4097388f7 100644 --- a/cpp_src/core/clusterproxy.cc +++ b/cpp_src/core/clusterproxy.cc @@ -1,14 +1,45 @@ #include "clusterproxy.h" +#include "client/itemimplbase.h" +#include "cluster/consts.h" #include "cluster/sharding/shardingcontrolrequest.h" -#include "core/cjson/jsonbuilder.h" -#include "core/defnsconfigs.h" +#include "core/system_ns_names.h" #include "estl/shared_mutex.h" - #include "namespacedef.h" #include "tools/catch_and_return.h" +#include "tools/clusterproxyloghelper.h" namespace reindexer { +#if RX_ENABLE_CLUSTERPROXY_LOGS +template ::value>::type* = nullptr> +static void printErr(const R& r) { + if (!r.ok()) { + 
clusterProxyLog(LogTrace, "[cluster proxy] Err: %s", r.what()); + } +} + +template ::value>::type* = nullptr> +static void printErr(const R& r) { + if (!r.Status().ok()) { + clusterProxyLog(LogTrace, "[cluster proxy] Tx err: %s", r.Status().what()); + } +} +#endif + +#define CallProxyFunction(Fn) proxyCall + +#define DefFunctor1(P1, F, Action) \ + std::function action = \ + std::bind(&ClusterProxy::Action, this, _1, _2, _3) + +#define DefFunctor2(P1, P2, F, Action) \ + std::function action = \ + std::bind(&ClusterProxy::Action, this, _1, _2, _3, _4) + +#define DefFunctor3(P1, P2, P3, F, Action) \ + std::function action = \ + std::bind(&ClusterProxy::Action, this, _1, _2, _3, _4, _5) + using namespace std::string_view_literals; // This method is for simple modify-requests, proxied by cluster (and for #replicationstats request) @@ -68,7 +99,7 @@ void ClusterProxy::clientToCoreQueryResults(client::QueryResults& clientResults, localTm.deserialize(ser, itemimpl.tagsMatcher().version(), itemimpl.tagsMatcher().stateToken()); } itemimpl.Value().SetLSN(item.GetLSN()); - result.AddItemRef(it.itemParams_.id, itemimpl.Value(), it.itemParams_.proc, it.itemParams_.nsid, true); + result.AddItemRef(it.itemParams_.rank, it.itemParams_.id, itemimpl.Value(), it.itemParams_.nsid, true); result.SaveRawData(std::move(itemimpl)); } } @@ -156,6 +187,319 @@ Error ClusterProxy::Connect(const std::string& dsn, ConnectOpts opts) { return err; } +Error ClusterProxy::OpenNamespace(std::string_view nsName, const StorageOpts& opts, const NsReplicationOpts& replOpts, + const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor3(std::string_view, const StorageOpts&, const NsReplicationOpts&, OpenNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::OpenNamespace", getServerIDRel()); + return CallProxyFunction(OpenNamespace)(ctx, nsName, action, nsName, opts, replOpts); +} + +Error ClusterProxy::AddNamespace(const NamespaceDef& nsDef, const NsReplicationOpts& replOpts, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(const NamespaceDef&, const NsReplicationOpts&, AddNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::AddNamespace", getServerIDRel()); + return CallProxyFunction(AddNamespace)(ctx, nsDef.name, action, nsDef, replOpts); +} + +Error ClusterProxy::CloseNamespace(std::string_view nsName, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor1(std::string_view, CloseNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropNamespace", getServerIDRel()); + return CallProxyFunction(CloseNamespace)(ctx, nsName, action, nsName); +} + +Error ClusterProxy::DropNamespace(std::string_view nsName, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor1(std::string_view, DropNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropNamespace", getServerIDRel()); + return CallProxyFunction(DropNamespace)(ctx, nsName, action, nsName); +} + +Error ClusterProxy::TruncateNamespace(std::string_view nsName, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor1(std::string_view, TruncateNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::TruncateNamespace", getServerIDRel()); + return CallProxyFunction(TruncateNamespace)(ctx, nsName, action, nsName); +} + +Error ClusterProxy::RenameNamespace(std::string_view srcNsName, const std::string& dstNsName, const RdxContext& 
ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, const std::string&, RenameNamespace, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::RenameNamespace", getServerIDRel()); + return CallProxyFunction(RenameNamespace)(ctx, std::string_view(), action, srcNsName, dstNsName); +} + +Error ClusterProxy::AddIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, const IndexDef&, AddIndex, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::AddIndex", getServerIDRel()); + return CallProxyFunction(AddIndex)(ctx, nsName, action, nsName, index); +} + +Error ClusterProxy::UpdateIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, const IndexDef&, UpdateIndex, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::UpdateIndex", getServerIDRel()); + return CallProxyFunction(UpdateIndex)(ctx, nsName, action, nsName, index); +} + +Error ClusterProxy::DropIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, const IndexDef&, DropIndex, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropIndex", getServerIDRel()); + return CallProxyFunction(DropIndex)(ctx, nsName, action, nsName, index); +} + +Error ClusterProxy::SetSchema(std::string_view nsName, std::string_view schema, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, std::string_view, SetSchema, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::SetSchema", getServerIDRel()); + return CallProxyFunction(SetSchema)(ctx, nsName, action, nsName, schema); +} + +Error ClusterProxy::GetSchema(std::string_view nsName, int format, std::string& schema, const RdxContext& ctx) { + return impl_.GetSchema(nsName, format, schema, ctx); +} + +Error ClusterProxy::EnumNamespaces(std::vector& defs, EnumNamespacesOpts opts, const RdxContext& ctx) { + return impl_.EnumNamespaces(defs, opts, ctx); +} + +Error ClusterProxy::Insert(std::string_view nsName, Item& item, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { + return itemFollowerAction<&client::Reindexer::Insert>(ctx, clientToLeader, nsName, item); + }; + return proxyCall(ctx, nsName, action, nsName, item); +} + +Error ClusterProxy::Insert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { + return resultItemFollowerAction<&client::Reindexer::Insert>(ctx, clientToLeader, nsName, item, qr); + }; + return proxyCall(ctx, nsName, action, nsName, item, qr); +} + +Error ClusterProxy::Update(std::string_view nsName, Item& item, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { + return itemFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, nsName, item); + }; + return proxyCall(ctx, nsName, action, nsName, item); +} + +Error ClusterProxy::Update(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, 
std::string_view nsName, Item& item, LocalQueryResults& qr) { + return resultItemFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, nsName, item, qr); + }; + return proxyCall(ctx, nsName, action, nsName, item, qr); +} + +Error ClusterProxy::Update(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { + return resultFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, q, qr); + }; + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Update query", getServerIDRel()); + return proxyCall(ctx, q.NsName(), action, q, qr); +} + +Error ClusterProxy::Upsert(std::string_view nsName, Item& item, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { + return itemFollowerAction<&client::Reindexer::Upsert>(ctx, clientToLeader, nsName, item); + }; + return proxyCall(ctx, nsName, action, nsName, item); +} + +Error ClusterProxy::Upsert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { + return resultItemFollowerAction<&client::Reindexer::Upsert>(ctx, clientToLeader, nsName, item, qr); + }; + return proxyCall(ctx, nsName, action, nsName, item, qr); +} + +Error ClusterProxy::Delete(std::string_view nsName, Item& item, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { + return itemFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, nsName, item); + }; + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Delete ITEM", getServerIDRel()); + return proxyCall(ctx, nsName, action, nsName, item); +} + +Error ClusterProxy::Delete(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { + return resultItemFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, nsName, item, qr); + }; + return proxyCall(ctx, nsName, action, nsName, item, qr); +} + +Error ClusterProxy::Delete(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { + return resultFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, q, qr); + }; + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Delete QUERY", getServerIDRel()); + return proxyCall(ctx, q.NsName(), action, q, qr); +} + +Error ClusterProxy::Select(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { + using namespace std::placeholders; + if (!shouldProxyQuery(q)) { + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Select query local", getServerIDRel()); + return impl_.Select(q, qr, ctx); + } + const RdxDeadlineContext deadlineCtx(kReplicationStatsTimeout, ctx.GetCancelCtx()); + const RdxContext rdxDeadlineCtx = ctx.WithCancelCtx(deadlineCtx); + + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { + return resultFollowerAction<&client::Reindexer::Select>(ctx, clientToLeader, q, qr); + }; + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Select query proxied", getServerIDRel()); + return proxyCall(rdxDeadlineCtx, 
q.NsName(), action, q, qr); +} + +Transaction ClusterProxy::NewTransaction(std::string_view nsName, const RdxContext& ctx) { + using LocalFT = LocalTransaction (ReindexerImpl::*)(std::string_view, const RdxContext&); + auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName) { + try { + client::Reindexer l = clientToLeader->WithEmmiterServerId(GetServerID()); + return Transaction(impl_.NewTransaction(nsName, ctx), std::move(l)); + } catch (const Error& err) { + return Transaction(err); + } + }; + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::NewTransaction", getServerIDRel()); + return proxyCall(ctx, nsName, action, nsName); +} + +Error ClusterProxy::CommitTransaction(Transaction& tr, QueryResults& qr, bool txExpectsSharding, const RdxContext& ctx) { + return tr.commit(GetServerID(), txExpectsSharding, impl_, qr, ctx); +} + +Error ClusterProxy::RollBackTransaction(Transaction& tr, const RdxContext& ctx) { + // + return tr.rollback(GetServerID(), ctx); +} + +Error ClusterProxy::GetMeta(std::string_view nsName, const std::string& key, std::string& data, const RdxContext& ctx) { + return impl_.GetMeta(nsName, key, data, ctx); +} + +Error ClusterProxy::PutMeta(std::string_view nsName, const std::string& key, std::string_view data, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor3(std::string_view, const std::string&, std::string_view, PutMeta, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::PutMeta", getServerIDRel()); + return CallProxyFunction(PutMeta)(ctx, nsName, action, nsName, key, data); +} + +Error ClusterProxy::EnumMeta(std::string_view nsName, std::vector& keys, const RdxContext& ctx) { + return impl_.EnumMeta(nsName, keys, ctx); +} + +Error ClusterProxy::DeleteMeta(std::string_view nsName, const std::string& key, const RdxContext& ctx) { + using namespace std::placeholders; + DefFunctor2(std::string_view, const std::string&, DeleteMeta, baseFollowerAction); + clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DeleteMeta", getServerIDRel()); + return CallProxyFunction(DeleteMeta)(ctx, nsName, action, nsName, key); +} + +Error ClusterProxy::GetSqlSuggestions(std::string_view sqlQuery, int pos, std::vector& suggestions, const RdxContext& ctx) { + return impl_.GetSqlSuggestions(sqlQuery, pos, suggestions, ctx); +} + +Error ClusterProxy::Status() noexcept { + if (connected_.load(std::memory_order_acquire)) { + return {}; + } + auto st = impl_.Status(); + if (st.ok()) { + return Error(errNotValid, "Reindexer's cluster proxy layer was not initialized properly"); + } + return st; +} + +Error ClusterProxy::GetProtobufSchema(WrSerializer& ser, std::vector& namespaces) { + // + return impl_.GetProtobufSchema(ser, namespaces); +} + +Error ClusterProxy::GetReplState(std::string_view nsName, ReplicationStateV2& state, const RdxContext& ctx) { + return impl_.GetReplState(nsName, state, ctx); +} + +Error ClusterProxy::SetClusterizationStatus(std::string_view nsName, const ClusterizationStatus& status, const RdxContext& ctx) { + return impl_.SetClusterizationStatus(nsName, status, ctx); +} + +Error ClusterProxy::InitSystemNamespaces() { + // + return impl_.InitSystemNamespaces(); +} + +Error ClusterProxy::ApplySnapshotChunk(std::string_view nsName, const SnapshotChunk& ch, const RdxContext& ctx) { + return impl_.ApplySnapshotChunk(nsName, ch, ctx); +} + +Error ClusterProxy::SuggestLeader(const cluster::NodeData& suggestion, cluster::NodeData& response) { + return impl_.SuggestLeader(suggestion, 
response); +} + +Error ClusterProxy::LeadersPing(const cluster::NodeData& leader) { + Error err = impl_.LeadersPing(leader); + if (err.ok()) { + std::unique_lock lck(processPingEventMutex_); + lastPingLeaderId_ = leader.serverId; + lck.unlock(); + processPingEvent_.notify_all(); + } + return err; +} + +Error ClusterProxy::GetRaftInfo(cluster::RaftInfo& info, const RdxContext& ctx) { + // + return impl_.GetRaftInfo(true, info, ctx); +} + +Error ClusterProxy::CreateTemporaryNamespace(std::string_view baseName, std::string& resultName, const StorageOpts& opts, lsn_t nsVersion, + const RdxContext& ctx) { + return impl_.CreateTemporaryNamespace(baseName, resultName, opts, nsVersion, ctx); +} + +Error ClusterProxy::GetSnapshot(std::string_view nsName, const SnapshotOpts& opts, Snapshot& snapshot, const RdxContext& ctx) { + return impl_.GetSnapshot(nsName, opts, snapshot, ctx); +} + +Error ClusterProxy::SetTagsMatcher(std::string_view nsName, TagsMatcher&& tm, const RdxContext& ctx) { + return impl_.SetTagsMatcher(nsName, std::move(tm), ctx); +} + +Error ClusterProxy::DumpIndex(std::ostream& os, std::string_view nsName, std::string_view index, const RdxContext& ctx) { + return impl_.DumpIndex(os, nsName, index, ctx); +} + +void ClusterProxy::ShutdownCluster() { + impl_.ShutdownCluster(); + clusterConns_.Shutdown(); + resetLeader(); +} + +Namespace::Ptr ClusterProxy::GetNamespacePtr(std::string_view nsName, const RdxContext& ctx) { + // + return impl_.getNamespace(nsName, ctx); +} + +Namespace::Ptr ClusterProxy::GetNamespacePtrNoThrow(std::string_view nsName, const RdxContext& ctx) { + // + return impl_.getNamespaceNoThrow(nsName, ctx); +} + +PayloadType ClusterProxy::GetPayloadType(std::string_view nsName) { + // + return impl_.getPayloadType(nsName); +} + +std::set ClusterProxy::GetFTIndexes(std::string_view nsName) { + // + return impl_.getFTIndexes(nsName); +} + bool ClusterProxy::shouldProxyQuery(const Query& q) { assertrx(q.Type() == QuerySelect); if (kReplicationStatsNamespace != q.NsName()) { @@ -210,6 +554,11 @@ Error ClusterProxy::ResetShardingConfig(std::optional c CATCH_AND_RETURN } +void ClusterProxy::SaveNewShardingConfigFile(const cluster::ShardingConfig& config) const { + // + impl_.saveNewShardingConfigFile(config); +} + #ifdef _MSC_VER #define REINDEXER_FUNC_NAME __FUNCSIG__ #else @@ -267,4 +616,273 @@ Error ClusterProxy::ShardingControlRequest(const sharding::ShardingControlReques return {}; } +Error ClusterProxy::SubscribeUpdates(IEventsObserver& observer, EventSubscriberConfig&& cfg) { + return impl_.SubscribeUpdates(observer, std::move(cfg)); +} + +Error ClusterProxy::UnsubscribeUpdates(IEventsObserver& observer) { + // + return impl_.UnsubscribeUpdates(observer); +} + +void ClusterProxy::ConnectionsMap::SetParams(int clientThreads, int clientConns, int clientConnConcurrency) { + std::lock_guard lck(mtx_); + clientThreads_ = clientThreads > 0 ? clientThreads : cluster::kDefaultClusterProxyConnThreads; + clientConns_ = clientConns > 0 ? (std::min(uint32_t(clientConns), kMaxClusterProxyConnCount)) : cluster::kDefaultClusterProxyConnCount; + clientConnConcurrency_ = clientConnConcurrency > 0 ? 
(std::min(uint32_t(clientConnConcurrency), kMaxClusterProxyConnConcurrency)) + : cluster::kDefaultClusterProxyCoroPerConn; +} + +std::shared_ptr ClusterProxy::ConnectionsMap::Get(const DSN& dsn) { + { + shared_lock lck(mtx_); + if (shutdown_) { + throw Error(errTerminated, "Proxy is already shut down"); + } + auto found = conns_.find(dsn); + if (found != conns_.end()) { + return found->second; + } + } + + client::ReindexerConfig cfg; + cfg.AppName = "cluster_proxy"; + cfg.EnableCompression = true; + cfg.RequestDedicatedThread = true; + + std::lock_guard lck(mtx_); + cfg.SyncRxCoroCount = clientConnConcurrency_; + if (shutdown_) { + throw Error(errTerminated, "Proxy is already shut down"); + } + auto found = conns_.find(dsn); + if (found != conns_.end()) { + return found->second; + } + auto res = std::make_shared(cfg, clientConns_, clientThreads_); + auto err = res->Connect(dsn); + if (!err.ok()) { + throw err; + } + conns_[dsn] = res; + return res; +} + +void ClusterProxy::ConnectionsMap::Shutdown() { + std::lock_guard lck(mtx_); + shutdown_ = true; + for (auto& conn : conns_) { + conn.second->Stop(); + } +} + +// No forwarding in proxy calls - the same args may be used multiple times in a row +template +R ClusterProxy::localCall(const RdxContext& ctx, Args&... args) { + if constexpr (std::is_same_v) { + return R((impl_.*fn)(args..., ctx)); + } else { + return (impl_.*fn)(args..., ctx); + } +} + +template +R ClusterProxy::proxyCall(const RdxContext& ctx, std::string_view nsName, const FnA& action, Args&... args) { + R r; + Error err; + if (ctx.GetOriginLSN().isEmpty()) { + ErrorCode errCode = errOK; + bool allowCandidateRole = true; + do { + cluster::RaftInfo info; + err = impl_.GetRaftInfo(allowCandidateRole, info, ctx); + if (!err.ok()) { + if (err.code() == errTimeout || err.code() == errCanceled) { + err = Error(err.code(), "Unable to get cluster's leader: %s", err.what()); + } + setErrorCode(r, std::move(err)); + return r; + } + if (info.role == cluster::RaftInfo::Role::None) { // fast way for non-cluster node + if (ctx.HasEmmiterServer()) { + setErrorCode(r, Error(errLogic, "Request was proxied to non-cluster node")); + return r; + } + r = localCall(ctx, args...); + errCode = getErrCode(err, r); + if (errCode == errWrongReplicationData && (!impl_.clusterConfig_ || !impl_.NamespaceIsInClusterConfig(nsName))) { + break; + } + continue; + } + const bool nsInClusterConf = impl_.NamespaceIsInClusterConfig(nsName); + if (!nsInClusterConf) { // ns is not in cluster + clusterProxyLog(LogTrace, "[%d proxy] proxyCall ns not in cluster config (local)", getServerIDRel()); + const bool firstError = (errCode == errOK); + r = localCall(ctx, args...); + errCode = getErrCode(err, r); + if (firstError) { + continue; + } + break; + } + if (info.role == cluster::RaftInfo::Role::Leader) { + resetLeader(); + clusterProxyLog(LogTrace, "[%d proxy] proxyCall RaftInfo::Role::Leader", getServerIDRel()); + r = localCall(ctx, args...); +#if RX_ENABLE_CLUSTERPROXY_LOGS + printErr(r); +#endif + // the only place, where errUpdateReplication may appear + } else if (info.role == cluster::RaftInfo::Role::Follower) { + if (ctx.HasEmmiterServer()) { + setErrorCode(r, Error(errAlreadyProxied, "Request was proxied to follower node")); + return r; + } + try { + auto clientToLeader = getLeader(info); + clusterProxyLog(LogTrace, "[%d proxy] proxyCall RaftInfo::Role::Follower", getServerIDRel()); + r = action(ctx, clientToLeader, args...); + } catch (Error e) { + setErrorCode(r, std::move(e)); + } + } else if (info.role == 
cluster::RaftInfo::Role::Candidate) { + allowCandidateRole = false; + errCode = errWrongReplicationData; + // Second attempt with awaiting of the role switch + continue; + } + errCode = getErrCode(err, r); + } while (errCode == errWrongReplicationData); + } else { + clusterProxyLog(LogTrace, "[%d proxy] proxyCall LSN not empty (local call)", getServerIDRel()); + r = localCall(ctx, args...); + // errWrongReplicationData means, that leader of the current node doesn't match leader from LSN + if (getErrCode(err, r) == errWrongReplicationData) { + cluster::RaftInfo info; + err = impl_.GetRaftInfo(false, info, ctx); + if (!err.ok()) { + if (err.code() == errTimeout || err.code() == errCanceled) { + err = Error(err.code(), "Unable to get cluster's leader: %s", err.what()); + } + setErrorCode(r, std::move(err)); + return r; + } + if (info.role != cluster::RaftInfo::Role::Follower) { + return r; + } + std::unique_lock lck(processPingEventMutex_); + auto waitRes = processPingEvent_.wait_for(lck, cluster::kLeaderPingInterval * 10); // Awaiting ping from current leader + if (waitRes == std::cv_status::timeout || lastPingLeaderId_ != ctx.GetOriginLSN().Server()) { + return r; + } + lck.unlock(); + return localCall(ctx, args...); + } + } + return r; +} + +template +Error ClusterProxy::baseFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, Args&&... args) { + try { + client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); + const auto ward = ctx.BeforeClusterProxy(); + Error err = (l.*fnl)(std::forward(args)...); + return err; + } catch (const Error& err) { + return err; + } +} + +template +Error ClusterProxy::itemFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { + try { + Error err; + + client::Item clientItem = clientToLeader->NewItem(nsName); + if (clientItem.Status().ok()) { + auto jsonData = item.impl_->GetCJSON(true); + err = clientItem.FromCJSON(jsonData); + if (!err.ok()) { + return err; + } + clientItem.SetPrecepts(item.impl_->GetPrecepts()); + client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); + { + const auto ward = ctx.BeforeClusterProxy(); + err = (l.*fnl)(nsName, clientItem); + } + if (!err.ok()) { + return err; + } + *item.impl_ = ItemImpl(clientItem.impl_->Type(), clientItem.impl_->tagsMatcher()); + err = item.FromCJSON(clientItem.GetCJSON()); + item.setID(clientItem.GetID()); + item.setLSN(clientItem.GetLSN()); + } else { + err = clientItem.Status(); + } + return err; + } catch (const Error& err) { + return err; + } +} + +template +Error ClusterProxy::resultFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, const Query& query, LocalQueryResults& qr) { + try { + Error err; + client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); + client::QueryResults clientResults; + { + const auto ward = ctx.BeforeClusterProxy(); + err = (l.*fnl)(query, clientResults); + } + if (!err.ok()) { + return err; + } + if (!query.GetJoinQueries().empty() || !query.GetMergeQueries().empty() || !query.GetSubQueries().empty()) { + return Error(errLogic, "Unable to proxy query with JOIN, MERGE or SUBQUERY"); + } + clientToCoreQueryResults(clientResults, qr); + return err; + } catch (const Error& err) { + return err; + } +} + +template +Error ClusterProxy::resultItemFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, + LocalQueryResults& qr) { + try { + Error err; + client::Item clientItem = clientToLeader->NewItem(nsName); + if 
(!clientItem.Status().ok()) { + return clientItem.Status(); + } + auto jsonData = item.impl_->GetCJSON(true); + err = clientItem.FromCJSON(jsonData); + if (!err.ok()) { + return err; + } + clientItem.SetPrecepts(item.impl_->GetPrecepts()); + client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); + client::QueryResults clientResults; + { + const auto ward = ctx.BeforeClusterProxy(); + err = (l.*fnl)(nsName, clientItem, clientResults); + } + if (!err.ok()) { + return err; + } + item.setID(clientItem.GetID()); + item.setLSN(clientItem.GetLSN()); + clientToCoreQueryResults(clientResults, qr); + return err; + } catch (const Error& err) { + return err; + } +} + } // namespace reindexer diff --git a/cpp_src/core/clusterproxy.h b/cpp_src/core/clusterproxy.h index f6986b9ec..560a1b84c 100644 --- a/cpp_src/core/clusterproxy.h +++ b/cpp_src/core/clusterproxy.h @@ -1,12 +1,9 @@ #pragma once #include -#include "client/itemimpl.h" #include "client/reindexer.h" #include "cluster/config.h" -#include "cluster/consts.h" #include "core/reindexer_impl/reindexerimpl.h" -#include "tools/clusterproxyloghelper.h" namespace reindexer { @@ -15,287 +12,82 @@ struct ShardingControlRequestData; struct ShardingControlResponseData; } // namespace sharding -#define CallProxyFunction(Fn) proxyCall - -#define DefFunctor1(P1, F, Action) \ - std::function action = \ - std::bind(&ClusterProxy::Action, this, _1, _2, _3) - -#define DefFunctor2(P1, P2, F, Action) \ - std::function action = \ - std::bind(&ClusterProxy::Action, this, _1, _2, _3, _4) - -#define DefFunctor3(P1, P2, P3, F, Action) \ - std::function action = \ - std::bind(&ClusterProxy::Action, this, _1, _2, _3, _4, _5) - class ClusterProxy { public: ClusterProxy(ReindexerConfig cfg, ActivityContainer& activities, ReindexerImpl::CallbackMap&& proxyCallbacks); ~ClusterProxy(); Error Connect(const std::string& dsn, ConnectOpts opts = ConnectOpts()); - Error OpenNamespace(std::string_view nsName, const StorageOpts& opts, const NsReplicationOpts& replOpts, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor3(std::string_view, const StorageOpts&, const NsReplicationOpts&, OpenNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::OpenNamespace", getServerIDRel()); - return CallProxyFunction(OpenNamespace)(ctx, nsName, action, nsName, opts, replOpts); - } - Error AddNamespace(const NamespaceDef& nsDef, const NsReplicationOpts& replOpts, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(const NamespaceDef&, const NsReplicationOpts&, AddNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::AddNamespace", getServerIDRel()); - return CallProxyFunction(AddNamespace)(ctx, nsDef.name, action, nsDef, replOpts); - } - Error CloseNamespace(std::string_view nsName, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor1(std::string_view, CloseNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropNamespace", getServerIDRel()); - return CallProxyFunction(CloseNamespace)(ctx, nsName, action, nsName); - } - Error DropNamespace(std::string_view nsName, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor1(std::string_view, DropNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropNamespace", getServerIDRel()); - return CallProxyFunction(DropNamespace)(ctx, nsName, action, nsName); - } - Error TruncateNamespace(std::string_view nsName, const RdxContext& ctx) 
{ - using namespace std::placeholders; - DefFunctor1(std::string_view, TruncateNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::TruncateNamespace", getServerIDRel()); - return CallProxyFunction(TruncateNamespace)(ctx, nsName, action, nsName); - } - Error RenameNamespace(std::string_view srcNsName, const std::string& dstNsName, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, const std::string&, RenameNamespace, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::RenameNamespace", getServerIDRel()); - return CallProxyFunction(RenameNamespace)(ctx, std::string_view(), action, srcNsName, dstNsName); - } - Error AddIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, const IndexDef&, AddIndex, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::AddIndex", getServerIDRel()); - return CallProxyFunction(AddIndex)(ctx, nsName, action, nsName, index); - } - Error UpdateIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, const IndexDef&, UpdateIndex, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::UpdateIndex", getServerIDRel()); - return CallProxyFunction(UpdateIndex)(ctx, nsName, action, nsName, index); - } - Error DropIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, const IndexDef&, DropIndex, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DropIndex", getServerIDRel()); - return CallProxyFunction(DropIndex)(ctx, nsName, action, nsName, index); - } - Error SetSchema(std::string_view nsName, std::string_view schema, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, std::string_view, SetSchema, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::SetSchema", getServerIDRel()); - return CallProxyFunction(SetSchema)(ctx, nsName, action, nsName, schema); - } - Error GetSchema(std::string_view nsName, int format, std::string& schema, const RdxContext& ctx) { - return impl_.GetSchema(nsName, format, schema, ctx); - } - Error EnumNamespaces(std::vector& defs, EnumNamespacesOpts opts, const RdxContext& ctx) { - return impl_.EnumNamespaces(defs, opts, ctx); - } - Error Insert(std::string_view nsName, Item& item, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { - return itemFollowerAction<&client::Reindexer::Insert>(ctx, clientToLeader, nsName, item); - }; - return proxyCall(ctx, nsName, action, nsName, item); - } - Error Insert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { - return resultItemFollowerAction<&client::Reindexer::Insert>(ctx, clientToLeader, nsName, item, qr); - }; - return proxyCall(ctx, nsName, action, nsName, item, qr); - } - Error Update(std::string_view nsName, Item& item, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { - return itemFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, nsName, item); - }; - 
return proxyCall(ctx, nsName, action, nsName, item); - } - Error Update(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { - return resultItemFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, nsName, item, qr); - }; - return proxyCall(ctx, nsName, action, nsName, item, qr); - } - Error Update(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { - return resultFollowerAction<&client::Reindexer::Update>(ctx, clientToLeader, q, qr); - }; - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Update query", getServerIDRel()); - return proxyCall(ctx, q.NsName(), action, q, qr); - } - Error Upsert(std::string_view nsName, Item& item, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { - return itemFollowerAction<&client::Reindexer::Upsert>(ctx, clientToLeader, nsName, item); - }; - return proxyCall(ctx, nsName, action, nsName, item); - } - Error Upsert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { - return resultItemFollowerAction<&client::Reindexer::Upsert>(ctx, clientToLeader, nsName, item, qr); - }; - return proxyCall(ctx, nsName, action, nsName, item, qr); - } - Error Delete(std::string_view nsName, Item& item, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { - return itemFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, nsName, item); - }; - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Delete ITEM", getServerIDRel()); - return proxyCall(ctx, nsName, action, nsName, item); - } - Error Delete(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, LocalQueryResults& qr) { - return resultItemFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, nsName, item, qr); - }; - return proxyCall(ctx, nsName, action, nsName, item, qr); - } - Error Delete(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { - return resultFollowerAction<&client::Reindexer::Delete>(ctx, clientToLeader, q, qr); - }; - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Delete QUERY", getServerIDRel()); - return proxyCall(ctx, q.NsName(), action, q, qr); - } - Error Select(const Query& q, LocalQueryResults& qr, const RdxContext& ctx) { - using namespace std::placeholders; - if (!shouldProxyQuery(q)) { - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Select query local", getServerIDRel()); - return impl_.Select(q, qr, ctx); - } - const RdxDeadlineContext deadlineCtx(kReplicationStatsTimeout, ctx.GetCancelCtx()); - const RdxContext rdxDeadlineCtx = ctx.WithCancelCtx(deadlineCtx); - - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, const Query& q, LocalQueryResults& qr) { - return resultFollowerAction<&client::Reindexer::Select>(ctx, clientToLeader, q, qr); - }; - 
clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::Select query proxied", getServerIDRel()); - return proxyCall(rdxDeadlineCtx, q.NsName(), action, q, qr); - } + Error OpenNamespace(std::string_view nsName, const StorageOpts& opts, const NsReplicationOpts& replOpts, const RdxContext& ctx); + Error AddNamespace(const NamespaceDef& nsDef, const NsReplicationOpts& replOpts, const RdxContext& ctx); + Error CloseNamespace(std::string_view nsName, const RdxContext& ctx); + Error DropNamespace(std::string_view nsName, const RdxContext& ctx); + Error TruncateNamespace(std::string_view nsName, const RdxContext& ctx); + Error RenameNamespace(std::string_view srcNsName, const std::string& dstNsName, const RdxContext& ctx); + Error AddIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx); + Error UpdateIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx); + Error DropIndex(std::string_view nsName, const IndexDef& index, const RdxContext& ctx); + Error SetSchema(std::string_view nsName, std::string_view schema, const RdxContext& ctx); + Error GetSchema(std::string_view nsName, int format, std::string& schema, const RdxContext& ctx); + Error EnumNamespaces(std::vector& defs, EnumNamespacesOpts opts, const RdxContext& ctx); + Error Insert(std::string_view nsName, Item& item, const RdxContext& ctx); + Error Insert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx); + Error Update(std::string_view nsName, Item& item, const RdxContext& ctx); + Error Update(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx); + Error Update(const Query& q, LocalQueryResults& qr, const RdxContext& ctx); + Error Upsert(std::string_view nsName, Item& item, const RdxContext& ctx); + Error Upsert(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx); + Error Delete(std::string_view nsName, Item& item, const RdxContext& ctx); + Error Delete(std::string_view nsName, Item& item, LocalQueryResults& qr, const RdxContext& ctx); + Error Delete(const Query& q, LocalQueryResults& qr, const RdxContext& ctx); + Error Select(const Query& q, LocalQueryResults& qr, const RdxContext& ctx); Item NewItem(std::string_view nsName, const RdxContext& ctx) { return impl_.NewItem(nsName, ctx); } - Transaction NewTransaction(std::string_view nsName, const RdxContext& ctx) { - using LocalFT = LocalTransaction (ReindexerImpl::*)(std::string_view, const RdxContext&); - auto action = [this](const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName) { - try { - client::Reindexer l = clientToLeader->WithEmmiterServerId(GetServerID()); - return Transaction(impl_.NewTransaction(nsName, ctx), std::move(l)); - } catch (const Error& err) { - return Transaction(err); - } - }; - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::NewTransaction", getServerIDRel()); - return proxyCall(ctx, nsName, action, nsName); - } - Error CommitTransaction(Transaction& tr, QueryResults& qr, bool txExpectsSharding, const RdxContext& ctx) { - return tr.commit(GetServerID(), txExpectsSharding, impl_, qr, ctx); - } - Error RollBackTransaction(Transaction& tr, const RdxContext& ctx) { return tr.rollback(GetServerID(), ctx); } + Transaction NewTransaction(std::string_view nsName, const RdxContext& ctx); + Error CommitTransaction(Transaction& tr, QueryResults& qr, bool txExpectsSharding, const RdxContext& ctx); + Error RollBackTransaction(Transaction& tr, const RdxContext& ctx); - Error GetMeta(std::string_view nsName, const 
std::string& key, std::string& data, const RdxContext& ctx) { - return impl_.GetMeta(nsName, key, data, ctx); - } - Error PutMeta(std::string_view nsName, const std::string& key, std::string_view data, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor3(std::string_view, const std::string&, std::string_view, PutMeta, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::PutMeta", getServerIDRel()); - return CallProxyFunction(PutMeta)(ctx, nsName, action, nsName, key, data); - } - Error EnumMeta(std::string_view nsName, std::vector& keys, const RdxContext& ctx) { - return impl_.EnumMeta(nsName, keys, ctx); - } - Error DeleteMeta(std::string_view nsName, const std::string& key, const RdxContext& ctx) { - using namespace std::placeholders; - DefFunctor2(std::string_view, const std::string&, DeleteMeta, baseFollowerAction); - clusterProxyLog(LogTrace, "[%d proxy] ClusterProxy::DeleteMeta", getServerIDRel()); - return CallProxyFunction(DeleteMeta)(ctx, nsName, action, nsName, key); - } + Error GetMeta(std::string_view nsName, const std::string& key, std::string& data, const RdxContext& ctx); + Error PutMeta(std::string_view nsName, const std::string& key, std::string_view data, const RdxContext& ctx); + Error EnumMeta(std::string_view nsName, std::vector& keys, const RdxContext& ctx); + Error DeleteMeta(std::string_view nsName, const std::string& key, const RdxContext& ctx); - Error GetSqlSuggestions(std::string_view sqlQuery, int pos, std::vector& suggestions, const RdxContext& ctx) { - return impl_.GetSqlSuggestions(sqlQuery, pos, suggestions, ctx); - } - Error Status() noexcept { - if (connected_.load(std::memory_order_acquire)) { - return {}; - } - auto st = impl_.Status(); - if (st.ok()) { - return Error(errNotValid, "Reindexer's cluster proxy layer was not initialized properly"); - } - return st; - } - Error GetProtobufSchema(WrSerializer& ser, std::vector& namespaces) { return impl_.GetProtobufSchema(ser, namespaces); } - Error GetReplState(std::string_view nsName, ReplicationStateV2& state, const RdxContext& ctx) { - return impl_.GetReplState(nsName, state, ctx); - } - Error SetClusterizationStatus(std::string_view nsName, const ClusterizationStatus& status, const RdxContext& ctx) { - return impl_.SetClusterizationStatus(nsName, status, ctx); - } + Error GetSqlSuggestions(std::string_view sqlQuery, int pos, std::vector& suggestions, const RdxContext& ctx); + Error Status() noexcept; + Error GetProtobufSchema(WrSerializer& ser, std::vector& namespaces); + Error GetReplState(std::string_view nsName, ReplicationStateV2& state, const RdxContext& ctx); + Error SetClusterizationStatus(std::string_view nsName, const ClusterizationStatus& status, const RdxContext& ctx); bool NeedTraceActivity() const noexcept { return impl_.NeedTraceActivity(); } - Error InitSystemNamespaces() { return impl_.InitSystemNamespaces(); } - Error ApplySnapshotChunk(std::string_view nsName, const SnapshotChunk& ch, const RdxContext& ctx) { - return impl_.ApplySnapshotChunk(nsName, ch, ctx); - } + Error InitSystemNamespaces(); + Error ApplySnapshotChunk(std::string_view nsName, const SnapshotChunk& ch, const RdxContext& ctx); - Error SuggestLeader(const cluster::NodeData& suggestion, cluster::NodeData& response) { - return impl_.SuggestLeader(suggestion, response); - } - Error LeadersPing(const cluster::NodeData& leader) { - Error err = impl_.LeadersPing(leader); - if (err.ok()) { - std::unique_lock lck(processPingEventMutex_); - lastPingLeaderId_ = leader.serverId; - 
lck.unlock(); - processPingEvent_.notify_all(); - } - return err; - } - Error GetRaftInfo(cluster::RaftInfo& info, const RdxContext& ctx) { return impl_.GetRaftInfo(true, info, ctx); } + Error SuggestLeader(const cluster::NodeData& suggestion, cluster::NodeData& response); + Error LeadersPing(const cluster::NodeData& leader); + Error GetRaftInfo(cluster::RaftInfo& info, const RdxContext& ctx); Error CreateTemporaryNamespace(std::string_view baseName, std::string& resultName, const StorageOpts& opts, lsn_t nsVersion, - const RdxContext& ctx) { - return impl_.CreateTemporaryNamespace(baseName, resultName, opts, nsVersion, ctx); - } - Error GetSnapshot(std::string_view nsName, const SnapshotOpts& opts, Snapshot& snapshot, const RdxContext& ctx) { - return impl_.GetSnapshot(nsName, opts, snapshot, ctx); - } + const RdxContext& ctx); + Error GetSnapshot(std::string_view nsName, const SnapshotOpts& opts, Snapshot& snapshot, const RdxContext& ctx); Error ClusterControlRequest(const ClusterControlRequestData& request) { return impl_.ClusterControlRequest(request); } - Error SetTagsMatcher(std::string_view nsName, TagsMatcher&& tm, const RdxContext& ctx) { - return impl_.SetTagsMatcher(nsName, std::move(tm), ctx); - } - Error DumpIndex(std::ostream& os, std::string_view nsName, std::string_view index, const RdxContext& ctx) { - return impl_.DumpIndex(os, nsName, index, ctx); - } - void ShutdownCluster() { - impl_.ShutdownCluster(); - clusterConns_.Shutdown(); - resetLeader(); - } + Error SetTagsMatcher(std::string_view nsName, TagsMatcher&& tm, const RdxContext& ctx); + Error DumpIndex(std::ostream& os, std::string_view nsName, std::string_view index, const RdxContext& ctx); + void ShutdownCluster(); intrusive_ptr> GetShardingConfig() const noexcept { return impl_.shardingConfig_.Get(); } - Namespace::Ptr GetNamespacePtr(std::string_view nsName, const RdxContext& ctx) { return impl_.getNamespace(nsName, ctx); } - Namespace::Ptr GetNamespacePtrNoThrow(std::string_view nsName, const RdxContext& ctx) { return impl_.getNamespaceNoThrow(nsName, ctx); } + Namespace::Ptr GetNamespacePtr(std::string_view nsName, const RdxContext& ctx); + Namespace::Ptr GetNamespacePtrNoThrow(std::string_view nsName, const RdxContext& ctx); - PayloadType GetPayloadType(std::string_view nsName) { return impl_.getPayloadType(nsName); } - std::set GetFTIndexes(std::string_view nsName) { return impl_.getFTIndexes(nsName); } + PayloadType GetPayloadType(std::string_view nsName); + std::set GetFTIndexes(std::string_view nsName); Error ResetShardingConfig(std::optional config = std::nullopt) noexcept; - void SaveNewShardingConfigFile(const cluster::ShardingConfig& config) const { impl_.saveNewShardingConfigFile(config); } + void SaveNewShardingConfigFile(const cluster::ShardingConfig& config) const; Error ShardingControlRequest(const sharding::ShardingControlRequestData& request, sharding::ShardingControlResponseData& response, const RdxContext& ctx) noexcept; - Error SubscribeUpdates(IEventsObserver& observer, EventSubscriberConfig&& cfg) { - return impl_.SubscribeUpdates(observer, std::move(cfg)); - } - Error UnsubscribeUpdates(IEventsObserver& observer) { return impl_.UnsubscribeUpdates(observer); } + Error SubscribeUpdates(IEventsObserver& observer, EventSubscriberConfig&& cfg); + Error UnsubscribeUpdates(IEventsObserver& observer); // REINDEX_WITH_V3_FOLLOWERS Error SubscribeUpdates(IUpdatesObserverV3* observer, const UpdatesFilters& filters, SubscriptionOpts opts) { @@ -320,56 +112,9 @@ class ClusterProxy { class 
ConnectionsMap { public: - void SetParams(int clientThreads, int clientConns, int clientConnConcurrency) { - std::lock_guard lck(mtx_); - clientThreads_ = clientThreads > 0 ? clientThreads : cluster::kDefaultClusterProxyConnThreads; - clientConns_ = - clientConns > 0 ? (std::min(uint32_t(clientConns), kMaxClusterProxyConnCount)) : cluster::kDefaultClusterProxyConnCount; - clientConnConcurrency_ = clientConnConcurrency > 0 - ? (std::min(uint32_t(clientConnConcurrency), kMaxClusterProxyConnConcurrency)) - : cluster::kDefaultClusterProxyCoroPerConn; - } - std::shared_ptr Get(const DSN& dsn) { - { - shared_lock lck(mtx_); - if (shutdown_) { - throw Error(errTerminated, "Proxy is already shut down"); - } - auto found = conns_.find(dsn); - if (found != conns_.end()) { - return found->second; - } - } - - client::ReindexerConfig cfg; - cfg.AppName = "cluster_proxy"; - cfg.EnableCompression = true; - cfg.RequestDedicatedThread = true; - - std::lock_guard lck(mtx_); - cfg.SyncRxCoroCount = clientConnConcurrency_; - if (shutdown_) { - throw Error(errTerminated, "Proxy is already shut down"); - } - auto found = conns_.find(dsn); - if (found != conns_.end()) { - return found->second; - } - auto res = std::make_shared(cfg, clientConns_, clientThreads_); - auto err = res->Connect(dsn); - if (!err.ok()) { - throw err; - } - conns_[dsn] = res; - return res; - } - void Shutdown() { - std::lock_guard lck(mtx_); - shutdown_ = true; - for (auto& conn : conns_) { - conn.second->Stop(); - } - } + void SetParams(int clientThreads, int clientConns, int clientConnConcurrency); + std::shared_ptr Get(const DSN& dsn); + void Shutdown(); private: shared_timed_mutex mtx_; @@ -410,226 +155,19 @@ class ClusterProxy { template Error shardingControlRequestAction(const RdxContext& ctx, Args&&... args) noexcept; -#if RX_ENABLE_CLUSTERPROXY_LOGS - - template ::value>::type* = nullptr> - static void printErr(const R& r) { - if (!r.ok()) { - clusterProxyLog(LogTrace, "[cluster proxy] Err: %s", r.what()); - } - } - - template ::value>::type* = nullptr> - static void printErr(const R& r) { - if (!r.Status().ok()) { - clusterProxyLog(LogTrace, "[cluster proxy] Tx err: %s", r.Status().what()); - } - } - -#endif - template - R localCall(const RdxContext& ctx, Args&&... args) { - if constexpr (std::is_same_v) { - return R((impl_.*fn)(std::forward(args)..., ctx)); - } else { - return (impl_.*fn)(std::forward(args)..., ctx); - } - } + R localCall(const RdxContext& ctx, Args&... args); template - R proxyCall(const RdxContext& ctx, std::string_view nsName, const FnA& action, Args&&... 
args) { - R r; - Error err; - if (ctx.GetOriginLSN().isEmpty()) { - ErrorCode errCode = errOK; - bool allowCandidateRole = true; - do { - cluster::RaftInfo info; - err = impl_.GetRaftInfo(allowCandidateRole, info, ctx); - if (!err.ok()) { - if (err.code() == errTimeout || err.code() == errCanceled) { - err = Error(err.code(), "Unable to get cluster's leader: %s", err.what()); - } - setErrorCode(r, std::move(err)); - return r; - } - if (info.role == cluster::RaftInfo::Role::None) { // fast way for non-cluster node - if (ctx.HasEmmiterServer()) { - setErrorCode(r, Error(errLogic, "Request was proxied to non-cluster node")); - return r; - } - r = localCall(ctx, std::forward(args)...); - errCode = getErrCode(err, r); - if (errCode == errWrongReplicationData && (!impl_.clusterConfig_ || !impl_.NamespaceIsInClusterConfig(nsName))) { - break; - } - continue; - } - const bool nsInClusterConf = impl_.NamespaceIsInClusterConfig(nsName); - if (!nsInClusterConf) { // ns is not in cluster - clusterProxyLog(LogTrace, "[%d proxy] proxyCall ns not in cluster config (local)", getServerIDRel()); - const bool firstError = (errCode == errOK); - r = localCall(ctx, std::forward(args)...); - errCode = getErrCode(err, r); - if (firstError) { - continue; - } - break; - } - if (info.role == cluster::RaftInfo::Role::Leader) { - resetLeader(); - clusterProxyLog(LogTrace, "[%d proxy] proxyCall RaftInfo::Role::Leader", getServerIDRel()); - r = localCall(ctx, std::forward(args)...); -#if RX_ENABLE_CLUSTERPROXY_LOGS - printErr(r); -#endif - // the only place, where errUpdateReplication may appear - } else if (info.role == cluster::RaftInfo::Role::Follower) { - if (ctx.HasEmmiterServer()) { - setErrorCode(r, Error(errAlreadyProxied, "Request was proxied to follower node")); - return r; - } - try { - auto clientToLeader = getLeader(info); - clusterProxyLog(LogTrace, "[%d proxy] proxyCall RaftInfo::Role::Follower", getServerIDRel()); - r = action(ctx, clientToLeader, std::forward(args)...); - } catch (Error e) { - setErrorCode(r, std::move(e)); - } - } else if (info.role == cluster::RaftInfo::Role::Candidate) { - allowCandidateRole = false; - errCode = errWrongReplicationData; - // Second attempt with awaiting of the role switch - continue; - } - errCode = getErrCode(err, r); - } while (errCode == errWrongReplicationData); - } else { - clusterProxyLog(LogTrace, "[%d proxy] proxyCall LSN not empty (local call)", getServerIDRel()); - r = localCall(ctx, std::forward(args)...); - // errWrongReplicationData means, that leader of the current node doesn't match leader from LSN - if (getErrCode(err, r) == errWrongReplicationData) { - cluster::RaftInfo info; - err = impl_.GetRaftInfo(false, info, ctx); - if (!err.ok()) { - if (err.code() == errTimeout || err.code() == errCanceled) { - err = Error(err.code(), "Unable to get cluster's leader: %s", err.what()); - } - setErrorCode(r, std::move(err)); - return r; - } - if (info.role != cluster::RaftInfo::Role::Follower) { - return r; - } - std::unique_lock lck(processPingEventMutex_); - auto waitRes = processPingEvent_.wait_for(lck, cluster::kLeaderPingInterval * 10); // Awaiting ping from current leader - if (waitRes == std::cv_status::timeout || lastPingLeaderId_ != ctx.GetOriginLSN().Server()) { - return r; - } - lck.unlock(); - return localCall(ctx, std::forward(args)...); - } - } - return r; - } + R proxyCall(const RdxContext& ctx, std::string_view nsName, const FnA& action, Args&... 
args); template - Error baseFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, Args&&... args) { - try { - client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); - const auto ward = ctx.BeforeClusterProxy(); - Error err = (l.*fnl)(std::forward(args)...); - return err; - } catch (const Error& err) { - return err; - } - } + Error baseFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, Args&&... args); template - Error itemFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item) { - try { - Error err; - - client::Item clientItem = clientToLeader->NewItem(nsName); - if (clientItem.Status().ok()) { - auto jsonData = item.impl_->GetCJSON(true); - err = clientItem.FromCJSON(jsonData); - if (!err.ok()) { - return err; - } - clientItem.SetPrecepts(item.impl_->GetPrecepts()); - client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); - { - const auto ward = ctx.BeforeClusterProxy(); - err = (l.*fnl)(nsName, clientItem); - } - if (!err.ok()) { - return err; - } - *item.impl_ = ItemImpl(clientItem.impl_->Type(), clientItem.impl_->tagsMatcher()); - err = item.FromCJSON(clientItem.GetCJSON()); - item.setID(clientItem.GetID()); - item.setLSN(clientItem.GetLSN()); - } else { - err = clientItem.Status(); - } - return err; - } catch (const Error& err) { - return err; - } - } + Error itemFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item); template - Error resultFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, const Query& query, LocalQueryResults& qr) { - try { - Error err; - client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); - client::QueryResults clientResults; - { - const auto ward = ctx.BeforeClusterProxy(); - err = (l.*fnl)(query, clientResults); - } - if (!err.ok()) { - return err; - } - if (!query.GetJoinQueries().empty() || !query.GetMergeQueries().empty() || !query.GetSubQueries().empty()) { - return Error(errLogic, "Unable to proxy query with JOIN, MERGE or SUBQUERY"); - } - clientToCoreQueryResults(clientResults, qr); - return err; - } catch (const Error& err) { - return err; - } - } + Error resultFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, const Query& query, LocalQueryResults& qr); template Error resultItemFollowerAction(const RdxContext& ctx, LeaderRefT clientToLeader, std::string_view nsName, Item& item, - LocalQueryResults& qr) { - try { - Error err; - client::Item clientItem = clientToLeader->NewItem(nsName); - if (!clientItem.Status().ok()) { - return clientItem.Status(); - } - auto jsonData = item.impl_->GetCJSON(true); - err = clientItem.FromCJSON(jsonData); - if (!err.ok()) { - return err; - } - clientItem.SetPrecepts(item.impl_->GetPrecepts()); - client::Reindexer l = clientToLeader->WithEmmiterServerId(sId_); - client::QueryResults clientResults; - { - const auto ward = ctx.BeforeClusterProxy(); - err = (l.*fnl)(nsName, clientItem, clientResults); - } - if (!err.ok()) { - return err; - } - item.setID(clientItem.GetID()); - item.setLSN(clientItem.GetLSN()); - clientToCoreQueryResults(clientResults, qr); - return err; - } catch (const Error& err) { - return err; - } - } + LocalQueryResults& qr); void clientToCoreQueryResults(client::QueryResults&, LocalQueryResults&); bool shouldProxyQuery(const Query& q); diff --git a/cpp_src/core/dbconfig.cc b/cpp_src/core/dbconfig.cc index 7c08f6204..de715ab29 100644 --- a/cpp_src/core/dbconfig.cc +++ b/cpp_src/core/dbconfig.cc @@ -1,13 +1,10 @@ #include 
"dbconfig.h" -#include #include -#include #include "cjson/jsonbuilder.h" #include "estl/smart_lock.h" #include "gason/gason.h" -#include "spdlog/fmt/fmt.h" #include "tools/jsontools.h" #include "tools/logger.h" #include "tools/serializer.h" @@ -404,6 +401,7 @@ Error NamespaceConfigData::FromJSON(const gason::JsonNode& v) { err = tryReadOptionalJsonValue(&errorString, v, "start_copy_policy_tx_size"sv, startCopyPolicyTxSize); err = tryReadOptionalJsonValue(&errorString, v, "copy_policy_multiplier"sv, copyPolicyMultiplier); err = tryReadOptionalJsonValue(&errorString, v, "tx_size_to_always_copy"sv, txSizeToAlwaysCopy); + err = tryReadOptionalJsonValue(&errorString, v, "tx_vec_insertion_threads"sv, txVecInsertionThreads); err = tryReadOptionalJsonValue(&errorString, v, "optimization_timeout_ms"sv, optimizationTimeout); err = tryReadOptionalJsonValue(&errorString, v, "optimization_sort_workers"sv, optimizationSortWorkers); (void)err; // ignored; Errors will be handled with errorString diff --git a/cpp_src/core/dbconfig.h b/cpp_src/core/dbconfig.h index 81f6eb31a..8b4b7a03e 100644 --- a/cpp_src/core/dbconfig.h +++ b/cpp_src/core/dbconfig.h @@ -3,7 +3,6 @@ #include #include #include -#include #include "cluster/config.h" #include "estl/fast_hash_map.h" #include "estl/shared_mutex.h" @@ -109,6 +108,7 @@ struct NamespaceConfigData { int startCopyPolicyTxSize = 10'000; int copyPolicyMultiplier = 5; int txSizeToAlwaysCopy = 100'000; + int txVecInsertionThreads = 4; int optimizationTimeout = 800; int optimizationSortWorkers = 4; int64_t walSize = 4'000'000; @@ -118,6 +118,7 @@ struct NamespaceConfigData { int64_t maxIterationsIdSetPreResult = 20'000; bool idxUpdatesCountingMode = false; int syncStorageFlushLimit = 20'000; + int annStorageCacheBuildTimeout = 5'000; NamespaceCacheConfigData cacheConfig; Error FromJSON(const gason::JsonNode& v); diff --git a/cpp_src/core/defnsconfigs.h b/cpp_src/core/defnsconfigs.h index c8e4fa4c0..c707871cf 100644 --- a/cpp_src/core/defnsconfigs.h +++ b/cpp_src/core/defnsconfigs.h @@ -1,18 +1,10 @@ #pragma once #include "namespacedef.h" +#include "system_ns_names.h" namespace reindexer { -constexpr char kPerfStatsNamespace[] = "#perfstats"; -constexpr char kQueriesPerfStatsNamespace[] = "#queriesperfstats"; -constexpr char kMemStatsNamespace[] = "#memstats"; -constexpr char kNamespacesNamespace[] = "#namespaces"; -constexpr char kConfigNamespace[] = "#config"; -constexpr char kActivityStatsNamespace[] = "#activitystats"; -constexpr char kClientsStatsNamespace[] = "#clientsstats"; -constexpr char kClusterConfigNamespace[] = "#clusterconfig"; -const std::string_view kReplicationStatsNamespace = "#replicationstats"; constexpr char kNsNameField[] = "name"; constexpr std::string_view kDefDBConfig[] = { @@ -52,6 +44,7 @@ constexpr std::string_view kDefDBConfig[] = { "start_copy_policy_tx_size":10000, "copy_policy_multiplier":5, "tx_size_to_always_copy":100000, + "tx_vec_insertion_threads":4, "optimization_timeout_ms":800, "optimization_sort_workers":4, "wal_size":4000000, @@ -71,7 +64,8 @@ constexpr std::string_view kDefDBConfig[] = { "joins_preselect_hit_to_cache":2, "query_count_cache_size":134217728, "query_count_hit_to_cache":2 - } + }, + "ann_storage_cache_build_timeout_ms": 5000 } ] })json", @@ -182,7 +176,7 @@ const NamespaceDef kSystemNsDefs[] = { .AddIndex("client_version", "-", "string", IndexOpts().Dense()) .AddIndex("app_name", "-", "string", IndexOpts().Dense()) .AddIndex("tx_count", "-", "int64", IndexOpts().Dense()), - 
NamespaceDef(std::string(kReplicationStatsNamespace), StorageOpts()) + NamespaceDef(kReplicationStatsNamespace, StorageOpts()) .AddIndex("type", "hash", "string", IndexOpts().PK()) .AddIndex("update_drops", "-", "int64", IndexOpts().Dense()) .AddIndex("pending_updates_count", "-", "int64", IndexOpts().Dense()) diff --git a/cpp_src/core/enums.h b/cpp_src/core/enums.h new file mode 100644 index 000000000..0d7876438 --- /dev/null +++ b/cpp_src/core/enums.h @@ -0,0 +1,50 @@ +#pragma once + +#include + +namespace reindexer { + +#define BOOL_ENUM(Name) \ + class [[nodiscard]] Name { \ + public: \ + constexpr explicit Name(bool v) noexcept : value_{v} {} \ + constexpr Name& operator|=(bool other) & noexcept { \ + value_ |= other; \ + return *this; \ + } \ + constexpr Name& operator&=(bool other) & noexcept { \ + value_ &= other; \ + return *this; \ + } \ + constexpr Name operator!() const noexcept { return Name{!value_}; } \ + constexpr Name operator||(Name other) const noexcept { return Name{value_ || other.value_}; } \ + constexpr Name operator&&(Name other) const noexcept { return Name{value_ && other.value_}; } \ + [[nodiscard]] constexpr bool operator==(Name other) const noexcept { return value_ == other.value_; } \ + [[nodiscard]] constexpr bool operator!=(Name other) const noexcept { return !operator==(other); } \ + [[nodiscard]] explicit constexpr operator bool() const noexcept { return value_; } \ + [[nodiscard]] constexpr bool operator*() const noexcept { return value_; } \ + \ + private: \ + bool value_; \ + }; \ + static constexpr Name Name##_True = Name(true); \ + static constexpr Name Name##_False = Name(false); + +BOOL_ENUM(IsRanked) +BOOL_ENUM(ContainRanked) +BOOL_ENUM(ForcedFirst) +BOOL_ENUM(CheckUnsigned) +BOOL_ENUM(NeedSort) +BOOL_ENUM(IsMergeQuery) +BOOL_ENUM(ReplaceDeleted) + +#undef BOOL_ENUM + +enum class [[nodiscard]] VectorMetric { L2, InnerProduct, Cosine }; +enum class [[nodiscard]] RankedTypeQuery { NotSet, No, FullText, KnnL2, KnnIP, KnnCos }; +enum class [[nodiscard]] RankSortType { RankOnly, RankAndID, ExternalExpression }; +enum class [[nodiscard]] RankOrdering { Off, Asc, Desc }; + +enum class [[nodiscard]] FloatVectorDimension : uint16_t { Zero = 0 }; + +} // namespace reindexer diff --git a/cpp_src/core/expressiontree.h b/cpp_src/core/expressiontree.h index d6275f886..2a4090c4e 100644 --- a/cpp_src/core/expressiontree.h +++ b/cpp_src/core/expressiontree.h @@ -125,13 +125,17 @@ class ExpressionTree { Node(Node&& other) noexcept : storage_{std::move(other.storage_)}, operation{std::move(other.operation)} {} ~Node() = default; RX_ALWAYS_INLINE Node& operator=(const Node& other) { - storage_ = other.storage_; - operation = other.operation; + if (this != &other) { + storage_ = other.storage_; + operation = other.operation; + } return *this; } RX_ALWAYS_INLINE Node& operator=(Node&& other) noexcept { - storage_ = std::move(other.storage_); - operation = std::move(other.operation); + if (this != &other) { + storage_ = std::move(other.storage_); + operation = std::move(other.operation); + } return *this; } RX_ALWAYS_INLINE bool operator==(const Node& other) const { diff --git a/cpp_src/core/formatters/jsonstring_fmt.h b/cpp_src/core/formatters/jsonstring_fmt.h index 29fba5d90..0dff3ef47 100644 --- a/cpp_src/core/formatters/jsonstring_fmt.h +++ b/cpp_src/core/formatters/jsonstring_fmt.h @@ -1,16 +1,20 @@ #pragma once #include "fmt/format.h" +#include "fmt/printf.h" #include "gason/gason.h" template <> -struct fmt::printf_formatter { - template - constexpr auto 
parse(ContextT& ctx) { - return ctx.begin(); - } +struct fmt::formatter : public fmt::formatter { template auto format(const gason::JsonString& s, ContextT& ctx) const { - return fmt::format_to(ctx.out(), "{}", std::string_view(s)); + return fmt::formatter::format(s, ctx); } }; + +namespace fmt { +template <> +inline auto formatter::format(const gason::JsonString& s, fmt::basic_printf_context& ctx) const { + return fmt::format_to(ctx.out(), "{}", std::string_view(s)); +} +} // namespace fmt diff --git a/cpp_src/core/formatters/key_string_fmt.h b/cpp_src/core/formatters/key_string_fmt.h index 765b78efd..7bb0a8be5 100644 --- a/cpp_src/core/formatters/key_string_fmt.h +++ b/cpp_src/core/formatters/key_string_fmt.h @@ -2,18 +2,7 @@ #include "core/keyvalue/key_string.h" #include "fmt/format.h" - -template <> -struct fmt::printf_formatter { - template - constexpr auto parse(ContextT& ctx) { - return ctx.begin(); - } - template - auto format(const reindexer::key_string& s, ContextT& ctx) const { - return s ? fmt::format_to(ctx.out(), "{}", std::string_view(s)) : fmt::format_to(ctx.out(), ""); - } -}; +#include "fmt/printf.h" template <> struct fmt::formatter : public fmt::formatter { @@ -22,3 +11,10 @@ struct fmt::formatter : public fmt::formatter::format(std::string_view(s), ctx) : fmt::format_to(ctx.out(), ""); } }; + +namespace fmt { +template <> +inline auto formatter::format(const reindexer::key_string& s, fmt::basic_printf_context& ctx) const { + return s ? fmt::format_to(ctx.out(), "{}", std::string_view(s)) : fmt::format_to(ctx.out(), ""); +} +} // namespace fmt diff --git a/cpp_src/core/formatters/namespacesname_fmt.h b/cpp_src/core/formatters/namespacesname_fmt.h index 4d42adfbb..5b5eae899 100644 --- a/cpp_src/core/formatters/namespacesname_fmt.h +++ b/cpp_src/core/formatters/namespacesname_fmt.h @@ -2,18 +2,7 @@ #include "core/namespace/namespacename.h" #include "fmt/format.h" - -template <> -struct fmt::printf_formatter { - template - constexpr auto parse(ContextT& ctx) { - return ctx.begin(); - } - template - auto format(const reindexer::NamespaceName& name, ContextT& ctx) const { - return fmt::format_to(ctx.out(), "{}", name.OriginalName()); - } -}; +#include "fmt/printf.h" template <> struct fmt::formatter : public fmt::formatter { @@ -22,3 +11,10 @@ struct fmt::formatter : public fmt::formatter::format(name.OriginalName(), ctx); } }; + +namespace fmt { +template <> +inline auto formatter::format(const reindexer::NamespaceName& name, fmt::basic_printf_context& ctx) const { + return fmt::format_to(ctx.out(), "{}", name.OriginalName()); +} +} // namespace fmt diff --git a/cpp_src/core/formatters/uuid_fmt.h b/cpp_src/core/formatters/uuid_fmt.h index e249d4b55..8fd46cd94 100644 --- a/cpp_src/core/formatters/uuid_fmt.h +++ b/cpp_src/core/formatters/uuid_fmt.h @@ -2,18 +2,7 @@ #include "core/keyvalue/uuid.h" #include "fmt/format.h" - -template <> -struct fmt::printf_formatter { - template - constexpr auto parse(ContextT& ctx) { - return ctx.begin(); - } - template - auto format(const reindexer::Uuid& uuid, ContextT& ctx) const { - return fmt::format_to(ctx.out(), "'{}'", std::string(uuid)); - } -}; +#include "fmt/printf.h" template <> struct fmt::formatter : public fmt::formatter { @@ -22,3 +11,10 @@ struct fmt::formatter : public fmt::formatter { return fmt::formatter::format(std::string(uuid), ctx); } }; + +namespace fmt { +template <> +inline auto formatter::format(const reindexer::Uuid& uuid, fmt::basic_printf_context& ctx) const { + return fmt::format_to(ctx.out(), "'{}'", 
std::string(uuid)); +} +} // namespace fmt diff --git a/cpp_src/core/ft/areaholder.h b/cpp_src/core/ft/areaholder.h index 3bb410fa8..d3446e8c3 100644 --- a/cpp_src/core/ft/areaholder.h +++ b/cpp_src/core/ft/areaholder.h @@ -1,9 +1,5 @@ #pragma once -#include -#include -#include -#include -#include "estl/h_vector.h" + #include "sort/pdqsort.hpp" #include "usingcontainer.h" diff --git a/cpp_src/core/ft/config/baseftconfig.cc b/cpp_src/core/ft/config/baseftconfig.cc index 4a46c1522..a020b5cf4 100644 --- a/cpp_src/core/ft/config/baseftconfig.cc +++ b/cpp_src/core/ft/config/baseftconfig.cc @@ -29,9 +29,9 @@ void BaseFTConfig::parseBase(const gason::JsonNode& root) { for (auto& sw : stopWordsNode) { std::string word; StopWord::Type type = StopWord::Type::Stop; - if (sw.value.getTag() == gason::JsonTag::JSON_STRING) { + if (sw.value.getTag() == gason::JsonTag::STRING) { word = sw.As(); - } else if (sw.value.getTag() == gason::JsonTag::JSON_OBJECT) { + } else if (sw.value.getTag() == gason::JsonTag::OBJECT) { word = sw["word"].As(); type = sw["is_morpheme"].As() ? StopWord::Type::Morpheme : StopWord::Type::Stop; } diff --git a/cpp_src/core/ft/config/ftfastconfig.cc b/cpp_src/core/ft/config/ftfastconfig.cc index 65c5ac5f4..4ed6212be 100644 --- a/cpp_src/core/ft/config/ftfastconfig.cc +++ b/cpp_src/core/ft/config/ftfastconfig.cc @@ -1,9 +1,8 @@ #include "ftfastconfig.h" #include -#include #include #include "core/cjson/jsonbuilder.h" -#include "core/ft/typos.h" +#include "core/ft/limits.h" #include "tools/errors.h" #include "tools/jsontools.h" @@ -52,7 +51,7 @@ void FtFastConfig::parse(std::string_view json, const RHashMap maxTypos = 2 * root["max_typos_in_word"].As<>(MaxTyposInWord(), 0, kMaxTyposInWord); } else { const auto& maxTyposNode = root["max_typos"]; - if (!maxTyposNode.empty() && maxTyposNode.value.getTag() != gason::JSON_NUMBER) { + if (!maxTyposNode.empty() && maxTyposNode.value.getTag() != gason::JsonTag::NUMBER) { throw Error(errParseDSL, "Fulltext configuration field 'max_typos' should be integer"); } maxTypos = maxTyposNode.As<>(maxTypos, 0, 2 * kMaxTyposInWord); diff --git a/cpp_src/core/ft/filters/synonyms.cc b/cpp_src/core/ft/filters/synonyms.cc index 6c5413bb6..b74ba00a3 100644 --- a/cpp_src/core/ft/filters/synonyms.cc +++ b/cpp_src/core/ft/filters/synonyms.cc @@ -163,7 +163,7 @@ void Synonyms::SetConfig(BaseFTConfig* cfg) { for (const std::wstring& singleAlt : *singleAlternatives) { multAlt->push_back({singleAlt}); } - many2any_.push_back({{resultOfSplit.begin(), resultOfSplit.end()}, std::move(multAlt)}); + many2any_.emplace_back(RVector{resultOfSplit.begin(), resultOfSplit.end()}, std::move(multAlt)); } if (!multipleAlternatives->empty()) { if (singleAlternatives->empty()) { diff --git a/cpp_src/core/ft/filters/translit.cc b/cpp_src/core/ft/filters/translit.cc index c8a5a0e38..8903a4d01 100644 --- a/cpp_src/core/ft/filters/translit.cc +++ b/cpp_src/core/ft/filters/translit.cc @@ -1,7 +1,7 @@ #include "translit.h" #include #include -#include "estl/span.h" +#include namespace reindexer { diff --git a/cpp_src/core/ft/ft_fast/dataholder.cc b/cpp_src/core/ft/ft_fast/dataholder.cc index fad988b72..30398418e 100644 --- a/cpp_src/core/ft/ft_fast/dataholder.cc +++ b/cpp_src/core/ft/ft_fast/dataholder.cc @@ -1,7 +1,7 @@ #include "dataholder.h" #include +#include "core/ft/ft_fast/frisosplitter.h" #include "dataprocessor.h" -#include "selecter.h" namespace reindexer { diff --git a/cpp_src/core/ft/ft_fast/dataholder.h b/cpp_src/core/ft/ft_fast/dataholder.h index 39b93ffff..31a5e4c2b 100644 
--- a/cpp_src/core/ft/ft_fast/dataholder.h +++ b/cpp_src/core/ft/ft_fast/dataholder.h @@ -1,17 +1,14 @@ #pragma once #include #include -#include "core/ft/areaholder.h" #include "core/ft/config/ftfastconfig.h" #include "core/ft/filters/itokenfilter.h" -#include "core/ft/ft_fast/frisosplitter.h" #include "core/ft/ft_fast/splitter.h" #include "core/ft/idrelset.h" #include "core/ft/limits.h" #include "core/ft/stemmer.h" #include "core/ft/typos.h" #include "core/ft/usingcontainer.h" -#include "core/index/ft_preselect.h" #include "core/index/indextext/ftkeyentry.h" #include "estl/flat_str_map.h" #include "estl/suffix_map.h" diff --git a/cpp_src/core/ft/ft_fast/dataprocessor.cc b/cpp_src/core/ft/ft_fast/dataprocessor.cc index 3cfe4bce7..b30afe879 100644 --- a/cpp_src/core/ft/ft_fast/dataprocessor.cc +++ b/cpp_src/core/ft/ft_fast/dataprocessor.cc @@ -121,8 +121,12 @@ size_t DataProcessor::commitIdRelSets(const WordsVector& preprocWords, w idsetcnt += sizeof(*wIt); } - word->vids.insert(word->vids.end(), std::make_move_iterator(keyIt->second.vids_.begin()), - std::make_move_iterator(keyIt->second.vids_.end())); + if constexpr (std::is_same_v) { + word->vids.insert(word->vids.end(), keyIt->second.vids_.begin(), keyIt->second.vids_.end()); + } else { + word->vids.insert(word->vids.end(), std::make_move_iterator(keyIt->second.vids_.begin()), + std::make_move_iterator(keyIt->second.vids_.end())); + } keyIt->second.vids_ = IdRelSet(); word->vids.shrink_to_fit(); idsetcnt += word->vids.heap_size(); diff --git a/cpp_src/core/ft/ft_fast/indextexttypes.h b/cpp_src/core/ft/ft_fast/indextexttypes.h index 8b2972e60..ae99edc65 100644 --- a/cpp_src/core/ft/ft_fast/indextexttypes.h +++ b/cpp_src/core/ft/ft_fast/indextexttypes.h @@ -1,6 +1,8 @@ #pragma once +#include #include "core/ft/limits.h" +#include "tools/assertrx.h" namespace reindexer { diff --git a/cpp_src/core/ft/ft_fast/selecter.cc b/cpp_src/core/ft/ft_fast/selecter.cc index 14d2d4795..4ee40eb59 100644 --- a/cpp_src/core/ft/ft_fast/selecter.cc +++ b/cpp_src/core/ft/ft_fast/selecter.cc @@ -120,8 +120,8 @@ void Selector::prepareVariants(std::vector& variants, RV // RX_NO_INLINE just for build test purpose. 
Do not expect any effect here template template -MergeType Selector::Process(FtDSLQuery&& dsl, bool inTransaction, FtSortType ftSortType, FtMergeStatuses::Statuses&& mergeStatuses, - const RdxContext& rdxCtx) { +MergeType Selector::Process(FtDSLQuery&& dsl, bool inTransaction, RankSortType rankSortType, + FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext& rdxCtx) { FtSelectContext ctx; ctx.rawResults.reserve(dsl.size()); // STEP 2: Search dsl terms for each variant @@ -213,10 +213,10 @@ MergeType Selector::Process(FtDSLQuery&& dsl, bool inTransaction, FtSort const auto maxMergedSize = std::min(size_t(holder_.cfg_->mergeLimit), ctx.totalORVids); if (maxMergedSize < 0xFFFF) { - return mergeResultsBmType(std::move(results), ctx.totalORVids, synonymsBounds, inTransaction, ftSortType, + return mergeResultsBmType(std::move(results), ctx.totalORVids, synonymsBounds, inTransaction, rankSortType, std::move(mergeStatuses), rdxCtx); } else if (maxMergedSize < 0xFFFFFFFF) { - return mergeResultsBmType(std::move(results), ctx.totalORVids, synonymsBounds, inTransaction, ftSortType, + return mergeResultsBmType(std::move(results), ctx.totalORVids, synonymsBounds, inTransaction, rankSortType, std::move(mergeStatuses), rdxCtx); } else { assertrx_throw(false); @@ -227,18 +227,18 @@ MergeType Selector::Process(FtDSLQuery&& dsl, bool inTransaction, FtSort template template MergeType Selector::mergeResultsBmType(std::vector&& results, size_t totalORVids, - const std::vector& synonymsBounds, bool inTransaction, FtSortType ftSortType, + const std::vector& synonymsBounds, bool inTransaction, RankSortType rankSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext& rdxCtx) { switch (holder_.cfg_->bm25Config.bm25Type) { case FtFastConfig::Bm25Config::Bm25Type::rx: return mergeResults(std::move(results), totalORVids, synonymsBounds, inTransaction, - ftSortType, std::move(mergeStatuses), rdxCtx); + rankSortType, std::move(mergeStatuses), rdxCtx); case FtFastConfig::Bm25Config::Bm25Type::classic: return mergeResults(std::move(results), totalORVids, synonymsBounds, inTransaction, - ftSortType, std::move(mergeStatuses), rdxCtx); + rankSortType, std::move(mergeStatuses), rdxCtx); case FtFastConfig::Bm25Config::Bm25Type::wordCount: return mergeResults(std::move(results), totalORVids, synonymsBounds, inTransaction, - ftSortType, std::move(mergeStatuses), rdxCtx); + rankSortType, std::move(mergeStatuses), rdxCtx); } assertrx_throw(false); return MergeType(); @@ -1364,7 +1364,7 @@ bool Selector::TyposHandler::isWordFitMaxLettPerm(const std::string_view template template MergedType Selector::mergeResults(std::vector&& rawResults, size_t maxMergedSize, - const std::vector& synonymsBounds, bool inTransaction, FtSortType ftSortType, + const std::vector& synonymsBounds, bool inTransaction, RankSortType rankSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext& rdxCtx) { const auto& vdocs = holder_.vdocs_; @@ -1471,17 +1471,17 @@ MergedType Selector::mergeResults(std::vector&& rawRe merged.maxRank = m.proc; } } - switch (ftSortType) { - case FtSortType::RankOnly: { + switch (rankSortType) { + case RankSortType::RankOnly: { boost::sort::pdqsort_branchless(merged.begin(), merged.end(), [](const MergeInfo& lhs, const MergeInfo& rhs) noexcept { return lhs.proc > rhs.proc; }); return merged; } - case FtSortType::RankAndID: { + case RankSortType::RankAndID: { return merged; } - case FtSortType::ExternalExpression: - throw Error(errLogic, "FtSortType::ExternalExpression not implemented."); + case 
RankSortType::ExternalExpression: + throw Error(errLogic, "RankSortType::ExternalExpression not implemented."); break; } return merged; @@ -1524,37 +1524,37 @@ void Selector::printVariants(const FtSelectContext& ctx, const TextSearc } template class Selector; -template MergeDataBase Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeDataBase Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process>(FtDSLQuery&&, bool, FtSortType, +template MergeData Selector::Process>(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); template MergeData Selector::Process>(FtDSLQuery&&, bool, - FtSortType, + RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeDataBase Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeDataBase Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); template class Selector; -template MergeDataBase Selector::Process(FtDSLQuery&&, bool, FtSortType, FtMergeStatuses::Statuses&&, +template MergeDataBase Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, FtMergeStatuses::Statuses&&, +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeDataBase Selector::Process(FtDSLQuery&&, bool, FtSortType, FtMergeStatuses::Statuses&&, +template MergeDataBase Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, FtMergeStatuses::Statuses&&, - const RdxContext&); -template MergeData Selector::Process(FtDSLQuery&&, bool, FtSortType, +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, + FtMergeStatuses::Statuses&&, const RdxContext&); +template MergeData Selector::Process(FtDSLQuery&&, bool, RankSortType, FtMergeStatuses::Statuses&&, const RdxContext&); } // namespace reindexer diff --git a/cpp_src/core/ft/ft_fast/selecter.h b/cpp_src/core/ft/ft_fast/selecter.h index 689312b93..77c1eb832 100644 --- a/cpp_src/core/ft/ft_fast/selecter.h +++ b/cpp_src/core/ft/ft_fast/selecter.h @@ -1,7 +1,9 @@ #pragma once +#include "core/enums.h" +#include "core/ft/areaholder.h" #include "core/ft/ftdsl.h" #include "core/ft/idrelset.h" -#include "core/selectfunc/ctx/ftctx.h" +#include "core/index/ft_preselect.h" #include "dataholder.h" namespace reindexer { @@ -63,8 +65,7 @@ class Selector { }; template - MergeType Process(FtDSLQuery&& dsl, bool inTransaction, FtSortType ftSortType, FtMergeStatuses::Statuses&& mergeStatuses, - const RdxContext&); + MergeType Process(FtDSLQuery&& dsl, bool inTransaction, RankSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext&); private: struct TextSearchResult { 
@@ -206,7 +207,7 @@ class Selector { template MergeType mergeResults(std::vector&& rawResults, size_t totalORVids, const std::vector& synonymsBounds, - bool inTransaction, FtSortType ftSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext&); + bool inTransaction, RankSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext&); template void mergeIteration(TextSearchResults& rawRes, index_t rawResIndex, FtMergeStatuses::Statuses& mergeStatuses, MergeType& merged, @@ -290,8 +291,7 @@ class Selector { template MergeType mergeResultsBmType(std::vector&& results, size_t totalORVids, const std::vector& synonymsBounds, - bool inTransaction, FtSortType ftSortType, FtMergeStatuses::Statuses&& mergeStatuses, - const RdxContext& rdxCtx); + bool inTransaction, RankSortType, FtMergeStatuses::Statuses&& mergeStatuses, const RdxContext& rdxCtx); void debugMergeStep(const char* msg, int vid, float normBm25, float normDist, int finalRank, int prevRank); template diff --git a/cpp_src/core/ft/ft_fuzzy/baseseacher.cc b/cpp_src/core/ft/ft_fuzzy/baseseacher.cc index 9b1789dd4..3c43775d0 100644 --- a/cpp_src/core/ft/ft_fuzzy/baseseacher.cc +++ b/cpp_src/core/ft/ft_fuzzy/baseseacher.cc @@ -1,11 +1,13 @@ #include "baseseacher.h" -#include -#include -#include #include "core/ft/ft_fuzzy/advacedpackedvec.h" #include "core/ft/ftdsl.h" #include "core/rdxcontext.h" #include "tools/stringstools.h" + +#ifdef FULL_LOG_FT +#include +#endif + namespace search_engine { using namespace reindexer; diff --git a/cpp_src/core/ft/ft_fuzzy/baseseacher.h b/cpp_src/core/ft/ft_fuzzy/baseseacher.h index eec4bc078..18878c98d 100644 --- a/cpp_src/core/ft/ft_fuzzy/baseseacher.h +++ b/cpp_src/core/ft/ft_fuzzy/baseseacher.h @@ -16,7 +16,7 @@ namespace search_engine { class BaseSearcher { public: - void AddSeacher(ITokenFilter::Ptr&& seacher); + void AddSeacher(reindexer::ITokenFilter::Ptr&& seacher); void AddIndex(BaseHolder::Ptr& holder, std::string_view src_data, const IdType id, int field, const std::string& extraWordSymbols); SearchResult Compare(const BaseHolder::Ptr& holder, const reindexer::FtDSLQuery& dsl, bool inTransaction, const reindexer::RdxContext&); @@ -30,10 +30,10 @@ class BaseSearcher { std::pair GetData(const BaseHolder::Ptr& holder, unsigned int i, wchar_t* buf, const wchar_t* src_data, size_t data_size); size_t ParseData(const BaseHolder::Ptr& holder, const std::wstring& src_data, int& max_id, int& min_id, - std::vector& rusults, const FtDslOpts& opts, double proc); + std::vector& rusults, const reindexer::FtDslOpts& opts, double proc); void AddIdToInfo(Info* info, const IdType id, std::pair pos, uint32_t total_size); uint32_t FindHash(const std::wstring& data); - std::vector> searchers_; + std::vector> searchers_; }; } // namespace search_engine diff --git a/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.cc b/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.cc index 936f81c98..2cf3b2691 100644 --- a/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.cc +++ b/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.cc @@ -1,5 +1,7 @@ #include "basebuildedholder.h" +using namespace reindexer; + namespace search_engine { DIt BaseHolder::GetData(const wchar_t* key) { diff --git a/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.h b/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.h index febe60736..5372a7a30 100644 --- a/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.h +++ b/cpp_src/core/ft/ft_fuzzy/dataholder/basebuildedholder.h @@ -1,10 +1,6 @@ #pragma once #include 
-#include #include -#include -#include -#include #include "core/ft/config/ftfuzzyconfig.h" #include "core/ft/ft_fuzzy/advacedpackedvec.h" #include "core/ft/idrelset.h" @@ -14,11 +10,9 @@ #include "tools/customhash.h" namespace search_engine { -using namespace reindexer; - #ifndef DEBUG_FT struct DataStructHash { - inline size_t operator()(const std::wstring& ent) const noexcept { return Hash(ent); } + inline size_t operator()(const std::wstring& ent) const noexcept { return reindexer::Hash(ent); } }; struct DataStructEQ { inline bool operator()(const std::wstring& ent, const std::wstring& ent1) const noexcept { return ent == ent1; } @@ -28,7 +22,7 @@ struct DataStructLess { }; template using data_map = tsl::hopscotch_map; -typedef fast_hash_set data_set; +typedef reindexer::fast_hash_set data_set; #else struct DataStructHash { @@ -37,10 +31,10 @@ struct DataStructHash { template using data_map = fast_hash_map; -typedef fast_hash_set data_set; +typedef reindexer::fast_hash_set data_set; #endif -typedef data_map::iterator DIt; -typedef fast_hash_map> word_size_map; +typedef data_map::iterator DIt; +typedef reindexer::fast_hash_map> word_size_map; class BaseHolder { public: @@ -53,7 +47,7 @@ class BaseHolder { BaseHolder& operator=(BaseHolder&&) noexcept = delete; void ClearTemp() { - data_map tmp_data; + data_map tmp_data; tmp_data_.swap(tmp_data); } DIt end() { return data_.end(); } @@ -62,17 +56,17 @@ class BaseHolder { ClearTemp(); data_.clear(); } - void SetConfig(const std::unique_ptr& cfg) { cfg_ = *cfg.get(); } + void SetConfig(const std::unique_ptr& cfg) { cfg_ = *cfg.get(); } DIt GetData(const wchar_t* key); - void SetSize(uint32_t size, VDocIdType id, int filed); - void AddDada(const wchar_t* key, VDocIdType id, int pos, int field); + void SetSize(uint32_t size, reindexer::VDocIdType id, int filed); + void AddDada(const wchar_t* key, reindexer::VDocIdType id, int pos, int field); void Commit(); public: - data_map tmp_data_; - data_map data_; + data_map tmp_data_; + data_map data_; word_size_map words_; - FtFuzzyConfig cfg_; + reindexer::FtFuzzyConfig cfg_; }; } // namespace search_engine diff --git a/cpp_src/core/ft/ft_fuzzy/dumper/fulltextdumper.cc b/cpp_src/core/ft/ft_fuzzy/dumper/fulltextdumper.cc index db88797b0..3a102186d 100644 --- a/cpp_src/core/ft/ft_fuzzy/dumper/fulltextdumper.cc +++ b/cpp_src/core/ft/ft_fuzzy/dumper/fulltextdumper.cc @@ -21,7 +21,8 @@ void FullTextDumper::LogFinalData(const reindexer::LocalQueryResults& result) { std::vector tmp_buffer; tmp_buffer.push_back("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"); tmp_buffer.push_back("Returned ids: "); - for (const auto& res : result.Items()) { + for (const auto& it : result.Items()) { + const auto& res = it.GetItemRef(); tmp_buffer.push_back("id: " + std::to_string(res.Id()) + " | lsn: " + std::to_string(int64_t(res.Value().GetLSN()))); } tmp_buffer.push_back("_______________________________________"); diff --git a/cpp_src/core/ft/ft_fuzzy/merger/basemerger.h b/cpp_src/core/ft/ft_fuzzy/merger/basemerger.h index c86300b04..9337b0646 100644 --- a/cpp_src/core/ft/ft_fuzzy/merger/basemerger.h +++ b/cpp_src/core/ft/ft_fuzzy/merger/basemerger.h @@ -14,12 +14,12 @@ class RdxContext; namespace search_engine { struct IDCtx { - const RVector* data; + const reindexer::RVector* data; int pos; double* max_proc; size_t total_size; - const FtDslOpts* opts; - const FtFuzzyConfig& cfg; + const reindexer::FtDslOpts* opts; + const reindexer::FtFuzzyConfig& cfg; double proc; word_size_map* sizes; }; @@ -55,18 +55,16 @@ struct 
SearchResult { double max_proc_; }; -using namespace reindexer; - struct FirstResult { - const AdvacedPackedVec* data; - const FtDslOpts* opts; + const reindexer::AdvacedPackedVec* data; + const reindexer::FtDslOpts* opts; int pos; double proc; }; struct MergeCtx { std::vector* results; - const FtFuzzyConfig* cfg; + const reindexer::FtFuzzyConfig* cfg; size_t total_size; word_size_map* sizes; }; @@ -75,7 +73,7 @@ class BaseMerger { public: BaseMerger(int max_id, int min_id); - SearchResult Merge(MergeCtx& ctx, bool inTransaction, const RdxContext&); + SearchResult Merge(MergeCtx& ctx, bool inTransaction, const reindexer::RdxContext&); private: int max_id_; diff --git a/cpp_src/core/ft/ft_fuzzy/searchengine.cc b/cpp_src/core/ft/ft_fuzzy/searchengine.cc index c91e6cb8c..00e70f596 100644 --- a/cpp_src/core/ft/ft_fuzzy/searchengine.cc +++ b/cpp_src/core/ft/ft_fuzzy/searchengine.cc @@ -1,11 +1,12 @@ #include "searchengine.h" #include -#include #include #include #include "core/ft/filters/kblayout.h" #include "core/ft/filters/translit.h" +using namespace reindexer; + namespace search_engine { SearchEngine::SearchEngine() { diff --git a/cpp_src/core/ft/ft_fuzzy/searchengine.h b/cpp_src/core/ft/ft_fuzzy/searchengine.h index 5a0a4eca3..90563855b 100644 --- a/cpp_src/core/ft/ft_fuzzy/searchengine.h +++ b/cpp_src/core/ft/ft_fuzzy/searchengine.h @@ -1,7 +1,5 @@ #pragma once #include -#include -#include #include "baseseacher.h" #include "core/ft/config/ftfuzzyconfig.h" #include "core/ft/ftdsl.h" @@ -18,12 +16,12 @@ class SearchEngine { typedef std::shared_ptr Ptr; SearchEngine(); - void SetConfig(const std::unique_ptr& cfg); + void SetConfig(const std::unique_ptr& cfg); SearchEngine(const SearchEngine& rhs) = delete; SearchEngine& operator=(const SearchEngine&) = delete; - SearchResult Search(const FtDSLQuery& dsl, bool inTransaction, const reindexer::RdxContext&); + SearchResult Search(const reindexer::FtDSLQuery& dsl, bool inTransaction, const reindexer::RdxContext&); void Rebuild(); void AddData(std::string_view src_data, const IdType id, int field, const std::string& extraWordSymbols); void Commit(); diff --git a/cpp_src/core/ft/ftdsl.h b/cpp_src/core/ft/ftdsl.h index 3ef26922e..4056e56f3 100644 --- a/cpp_src/core/ft/ftdsl.h +++ b/cpp_src/core/ft/ftdsl.h @@ -2,7 +2,6 @@ #include #include -#include #include "core/type_consts.h" #include "estl/h_vector.h" #include "stopwords/types.h" @@ -38,6 +37,10 @@ struct FtDSLEntry { FtDslOpts opts; }; +#if !defined(__clang__) && !defined(_MSC_VER) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif struct FtDSLVariant { FtDSLVariant() = default; FtDSLVariant(std::wstring p, int pr) noexcept : pattern{std::move(p)}, proc{pr} {} @@ -45,6 +48,9 @@ struct FtDSLVariant { std::wstring pattern; int proc = 0; }; +#if !defined(__clang__) && !defined(_MSC_VER) +#pragma GCC diagnostic pop +#endif struct StopWord; diff --git a/cpp_src/core/ft/idrelset.h b/cpp_src/core/ft/idrelset.h index fc967e460..747ecfe72 100644 --- a/cpp_src/core/ft/idrelset.h +++ b/cpp_src/core/ft/idrelset.h @@ -2,8 +2,6 @@ #include #include -#include -#include #include "estl/packed_vector.h" #include "sort/pdqsort.hpp" #include "usingcontainer.h" diff --git a/cpp_src/core/ft/usingcontainer.h b/cpp_src/core/ft/usingcontainer.h index 4edfe1549..d8b4c1bd4 100644 --- a/cpp_src/core/ft/usingcontainer.h +++ b/cpp_src/core/ft/usingcontainer.h @@ -1,10 +1,13 @@ #pragma once + +// #define REINDEX_FT_EXTRA_DEBUG + +#ifdef REINDEX_FT_EXTRA_DEBUG #include +#endif // 
REINDEX_FT_EXTRA_DEBUG #include "estl/fast_hash_map.h" #include "estl/h_vector.h" -// #define REINDEX_FT_EXTRA_DEBUG - namespace reindexer { #ifdef REINDEX_FT_EXTRA_DEBUG diff --git a/cpp_src/core/idset.h b/cpp_src/core/idset.h index 167370cc6..84b8ed628 100644 --- a/cpp_src/core/idset.h +++ b/cpp_src/core/idset.h @@ -8,7 +8,7 @@ #include "cpp-btree/btree_set.h" #include "estl/h_vector.h" #include "estl/intrusive_ptr.h" -#include "estl/span.h" +#include #include "sort/pdqsort.hpp" namespace reindexer { @@ -75,7 +75,6 @@ class IdSetPlain : protected base_idset { void ReserveForSorted(int sortedIdxCount) { reserve(size() * (sortedIdxCount + 1)); } std::string Dump() const; -protected: IdSetPlain(base_idset&& idset) noexcept : base_idset(std::move(idset)) {} }; @@ -148,6 +147,11 @@ class IdSet : public IdSetPlain { push_back(id); } + void SetUnordered(IdSetPlain&& other) { + assertrx(!set_); + IdSetPlain::operator=(std::move(other)); + } + template void Append(InputIt first, InputIt last, EditMode editMode = Auto) { if (editMode == Unordered) { @@ -234,7 +238,7 @@ class IdSet : public IdSetPlain { std::atomic usingBtree_; }; -using IdSetRef = span; -using IdSetCRef = span; +using IdSetRef = std::span; +using IdSetCRef = std::span; } // namespace reindexer diff --git a/cpp_src/core/index/float_vector/float_vector_index.cc b/cpp_src/core/index/float_vector/float_vector_index.cc new file mode 100644 index 000000000..ae5f9116f --- /dev/null +++ b/cpp_src/core/index/float_vector/float_vector_index.cc @@ -0,0 +1,96 @@ +#include "float_vector_index.h" +#include "tools/assertrx.h" + +namespace reindexer { + +FloatVectorIndex::FloatVectorIndex(const IndexDef& idef, PayloadType&& pt, FieldsSet&& fields) + : Index{idef, std::move(pt), std::move(fields)} { + assertrx(idef.Opts().IsFloatVector()); // TODO _dbg + assert(!idef.Opts().IsArray()); // TODO remove this + keyType_ = selectKeyType_ = KeyValueType::FloatVector{}; + memStat_.name = name_; + metric_ = idef.Opts().FloatVector().Metric(); +} + +void FloatVectorIndex::Delete(const VariantArray& keys, IdType id, StringsHolder& stringsHolder, bool& clearCache) { + assertrx(keys.size() == 1); // TODO _dbg + const auto it = emptyValues_.find(id); + if (it == emptyValues_.end()) { + Delete(keys[0], id, stringsHolder, clearCache); + } else { + emptyValues_.erase(it); + memStat_.uniqKeysCount = emptyValues_.empty() ? 
0 : 1; + } +} + +SelectKeyResults FloatVectorIndex::SelectKey(const VariantArray&, CondType, SortType, SelectOpts, const BaseFunctionCtx::Ptr&, + const RdxContext&) { + throw_as_assert; +} + +void FloatVectorIndex::Upsert(VariantArray& result, const VariantArray& keys, IdType id, bool& clearCache) { + assertrx(keys.size() == 1); // TODO _dbg + result.emplace_back(Upsert(keys[0], id, clearCache)); +} + +Variant FloatVectorIndex::Upsert(const Variant& key, IdType id, bool& clearCache) { + const ConstFloatVectorView vect{key}; + if (vect.IsEmpty()) { + emptyValues_.insert(id); + memStat_.uniqKeysCount = 1; + return Variant{ConstFloatVectorView{}}; + } + if (vect.Dimension() != Dimension()) { + throw Error{errNotValid, "Attempt to upsert vector of dimension %d in a float vector index of dimension %d", + size_t(vect.Dimension()), size_t(Dimension())}; + } + return upsert(vect, id, clearCache); +} + +SelectKeyResult FloatVectorIndex::Select(ConstFloatVectorView key, const KnnSearchParams& p, KnnCtx& ctx) const { + if (key.IsEmpty()) { + throw Error{errNotValid, "Attempt to search knn by empty float vector"}; + } + if (key.Dimension() != Dimension()) { + throw Error{errNotValid, "Attempt to search knn by float vector of dimension %d in float vector index of dimension %d", + size_t(key.Dimension()), size_t(Dimension())}; + } + return select(key, p, ctx); +} + +IndexMemStat FloatVectorIndex::GetMemStat(const RdxContext&) noexcept { + memStat_.indexingStructSize = emptyValues_.allocated_mem_size(); + return memStat_; +} + +FloatVector FloatVectorIndex::GetFloatVector(IdType id) const { + const auto it = emptyValues_.find(id); + if (it == emptyValues_.end()) { + return getFloatVector(id); + } else { + return {}; + } +} + +ConstFloatVectorView FloatVectorIndex::GetFloatVectorView(IdType id) const { + const auto it = emptyValues_.find(id); + if (it == emptyValues_.end()) { + return getFloatVectorView(id); + } else { + return {}; + } +} + +void FloatVectorIndex::WriterBase::writePK(IdType id) { + VariantArray pks = getPK_(id); + if (!isCompositePK_) { + ser_.PutVariant(pks[0]); + } else { + ser_.PutVarUint(pks.size()); + for (auto& v : pks) { + ser_.PutVariant(v); + } + } +} + +} // namespace reindexer diff --git a/cpp_src/core/index/float_vector/float_vector_index.h b/cpp_src/core/index/float_vector/float_vector_index.h new file mode 100644 index 000000000..dc3dc4a5f --- /dev/null +++ b/cpp_src/core/index/float_vector/float_vector_index.h @@ -0,0 +1,113 @@ +#pragma once + +#include "core/index/index.h" + +namespace reindexer { + +class KnnCtx; +class KnnSearchParams; + +class IPKWirter { +public: + virtual void AppendPKByID(IdType, WrSerializer&) = 0; +}; + +class FloatVectorIndex : public Index { +public: + using VecDataGetterF = std::function; + using PKGetterF = std::function; + +protected: + class WriterBase { + protected: + WriterBase(WrSerializer& ser, PKGetterF&& getPK, bool isCompositePK) noexcept + : ser_{ser}, getPK_{std::move(getPK)}, isCompositePK_{isCompositePK} {} + + void writePK(IdType); + + WrSerializer& ser_; + PKGetterF getPK_; + + private: + const bool isCompositePK_; + }; + + class LoaderBase { + protected: + LoaderBase(VecDataGetterF&& getVectorData, bool isCompositePK) noexcept + : getVectorData_{std::move(getVectorData)}, isCompositePK_{isCompositePK} {} + + IdType readPKEncodedData(void* destBuf, Serializer& ser, std::string_view name, std::string_view idxType) { + VariantArray keys; + if (!isCompositePK_) { + keys.emplace_back(ser.GetVariant()); + } else { + const auto len = 
ser.GetVarUInt(); + if rx_unlikely (!len) { + throw Error(errLogic, "%s::LoadIndexCache:%s: serialized PK array is empty", idxType, name); + } + keys.reserve(len); + for (size_t i = 0; i < len; ++i) { + keys.emplace_back(ser.GetVariant()); + } + } + const IdType itemID = getVectorData_(keys, destBuf); + if rx_unlikely (itemID < 0) { + throw Error(errLogic, "%s::LoadIndexCache:%s: unable to find indexed item with requested PK", idxType, name); + } + return itemID; + } + + private: + VecDataGetterF getVectorData_; + const bool isCompositePK_; + }; + +public: + struct StorageCacheWriteResult { + Error err; + bool isCacheable = false; + }; + + FloatVectorIndex(const IndexDef&, PayloadType&&, FieldsSet&&); + void Delete(const VariantArray& keys, IdType, StringsHolder&, bool& clearCache) override final; + using Index::Delete; + [[noreturn]] SelectKeyResults SelectKey(const VariantArray&, CondType, SortType, SelectOpts, const BaseFunctionCtx::Ptr&, + const RdxContext&) override final; + void Upsert(VariantArray& result, const VariantArray& keys, IdType, bool& clearCache) override final; + Variant Upsert(const Variant& key, IdType id, bool& clearCache) override final; + SelectKeyResult Select(ConstFloatVectorView, const KnnSearchParams&, KnnCtx&) const; + void Commit() noexcept override final {} + void UpdateSortedIds(const UpdateSortedContext&) noexcept override final {} + const void* ColumnData() const noexcept override final { return nullptr; } + bool HoldsStrings() const noexcept override final { return false; } + void ReconfigureCache(const NamespaceCacheConfigData&) noexcept override final {} + IndexMemStat GetMemStat(const RdxContext&) noexcept override; + FloatVector GetFloatVector(IdType) const; + ConstFloatVectorView GetFloatVectorView(IdType) const; + [[nodiscard]] uint64_t GetHash(IdType rowId) const { return GetFloatVectorView(rowId).Hash(); } + [[nodiscard]] reindexer::FloatVectorDimension Dimension() const noexcept { + return reindexer::FloatVectorDimension(Opts().FloatVector().Dimension()); + } + RankedTypeQuery RankedType() const noexcept override final { return ToRankedTypeQuery(metric_); } + [[nodiscard]] reindexer::FloatVectorDimension FloatVectorDimension() const noexcept override final { return Dimension(); } + virtual StorageCacheWriteResult WriteIndexCache(WrSerializer&, PKGetterF&&, bool isCompositePK, + const std::atomic_int32_t& cancel) noexcept = 0; + virtual Error LoadIndexCache(std::string_view data, bool isCompositePK, VecDataGetterF&& getVecData) = 0; + virtual void RebuildCentroids(float /*dataPart*/) {} + +private: + virtual SelectKeyResult select(ConstFloatVectorView, const KnnSearchParams&, KnnCtx&) const = 0; + virtual Variant upsert(ConstFloatVectorView, IdType id, bool& clearCache) = 0; + + virtual FloatVector getFloatVector(IdType) const = 0; + virtual ConstFloatVectorView getFloatVectorView(IdType) const = 0; + + IndexMemStat memStat_; + tsl::hopscotch_sc_set emptyValues_; + +protected: + VectorMetric metric_; +}; + +} // namespace reindexer diff --git a/cpp_src/core/index/float_vector/hnsw_index.cc b/cpp_src/core/index/float_vector/hnsw_index.cc new file mode 100644 index 000000000..10f6afdd2 --- /dev/null +++ b/cpp_src/core/index/float_vector/hnsw_index.cc @@ -0,0 +1,416 @@ +#if RX_WITH_BUILTIN_ANN_INDEXES + +#include "hnsw_index.h" +#include "core/query/knn_search_params.h" +#include "core/selectfunc/ctx/knn_ctx.h" +#include "tools/logger.h" +#include "tools/normalize.h" + +namespace reindexer { + +static_assert(sizeof(IdType) == 
sizeof(hnswlib::labeltype), "Expecting 1-to-1 mapping"); + +static void PrintVecInstrcutionsLevel(std::string_view indexType, std::string_view name, VectorMetric metric, Index::CreationLog log) { + std::string vecInstructions = "disabled"; + switch (metric) { + case VectorMetric::L2: + if (vector_dists::L2WithAVX512()) { + vecInstructions = "avx512"; + } else if (vector_dists::L2WithAVX()) { + vecInstructions = "avx"; + } else if (vector_dists::L2WithSSE()) { + vecInstructions = "sse"; + } + break; + case VectorMetric::InnerProduct: + if (vector_dists::InnerProductWithAVX512()) { + vecInstructions = "avx512"; + } else if (vector_dists::InnerProductWithAVX()) { + vecInstructions = "avx"; + } else if (vector_dists::InnerProductWithSSE()) { + vecInstructions = "sse"; + } + break; + case VectorMetric::Cosine: + if (hnswlib::CosineWithAVX512()) { + vecInstructions = "avx512"; + } else if (hnswlib::CosineWithAVX()) { + vecInstructions = "avx"; + } else if (hnswlib::CosineWithSSE()) { + vecInstructions = "sse"; + } + break; + default: + throw Error(errLogic, "Attempt to construct %s index '%s' with unknow metric: %d", indexType, name, int(metric)); + } + if (log == Index::CreationLog::Yes) { + logFmt(LogInfo, "Creating {} index '{}'; Vector instructions level: {}", indexType, name, vecInstructions); + } +} + +constexpr static ReplaceDeleted kHNSWAllowReplaceDeleted = ReplaceDeleted_True; + +template <> +HnswIndexBase::HnswIndexBase(const IndexDef& idef, PayloadType&& payloadType, FieldsSet&& fields, + size_t currentNsSize, CreationLog log) + : Base{idef, std::move(payloadType), std::move(fields)}, + space_{newSpace(size_t(Dimension()), metric_)}, + map_{std::make_unique>( + space_.get(), std::max(idef.Opts().FloatVector().StartSize(), currentNsSize), idef.Opts().FloatVector().M(), + idef.Opts().FloatVector().EfConstruction())} { + map_->allow_replace_deleted_ = kHNSWAllowReplaceDeleted; + PrintVecInstrcutionsLevel("singlethread HNSW", idef.Name(), idef.Opts().FloatVector().Metric(), log); +} + +template <> +HnswIndexBase::HnswIndexBase(const IndexDef& idef, PayloadType&& payloadType, FieldsSet&& fields, + size_t currentNsSize, CreationLog log) + : Base{idef, std::move(payloadType), std::move(fields)}, + space_{newSpace(size_t(Dimension()), metric_)}, + map_{std::make_unique>( + space_.get(), std::max(idef.Opts().FloatVector().StartSize(), currentNsSize), idef.Opts().FloatVector().M(), + idef.Opts().FloatVector().EfConstruction())} { + map_->allow_replace_deleted_ = kHNSWAllowReplaceDeleted; + PrintVecInstrcutionsLevel("multithread HNSW", idef.Name(), idef.Opts().FloatVector().Metric(), log); +} + +template <> +HnswIndexBase::HnswIndexBase(const IndexDef& idef, PayloadType&& payloadType, FieldsSet&& fields, + size_t currentNsSize, CreationLog log) + : Base{idef, std::move(payloadType), std::move(fields)}, + space_{newSpace(size_t(Dimension()), metric_)}, + map_{std::make_unique>(space_.get(), + std::max(idef.Opts().FloatVector().StartSize(), currentNsSize))} { + PrintVecInstrcutionsLevel("bruteforce", idef.Name(), idef.Opts().FloatVector().Metric(), log); +} + +template