From 0b4ed0161de75ae552f559a5010957a41805d278 Mon Sep 17 00:00:00 2001 From: Maxence Maireaux Date: Wed, 21 Jun 2023 15:20:11 +0200 Subject: [PATCH] fix: Use regex and cross segment (#448) --- pkg/storage/sqlstorage/accounts.go | 7 ++++++- pkg/storage/sqlstorage/balances.go | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/pkg/storage/sqlstorage/accounts.go b/pkg/storage/sqlstorage/accounts.go index eb20f3a6a..212ecba84 100644 --- a/pkg/storage/sqlstorage/accounts.go +++ b/pkg/storage/sqlstorage/accounts.go @@ -39,7 +39,11 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu switch s.Schema().Flavor() { case sqlbuilder.PostgreSQL: src := strings.Split(address, ":") - sb.Where(fmt.Sprintf("jsonb_array_length(address_json) = %d", len(src))) + if address[len(address)-2:] != ".*" { + sb.Where(fmt.Sprintf("jsonb_array_length(address_json) = %d", len(src))) + } else { + src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2] + } for i, segment := range src { if segment == ".*" || segment == "*" || segment == "" { @@ -49,6 +53,7 @@ func (s *Store) buildAccountsQuery(p ledger.AccountsQuery) (*sqlbuilder.SelectBu operator := "==" if !accountNameRegex.MatchString(segment) { operator = "like_regex" + segment = strings.ReplaceAll(segment, "\\", "\\\\") } arg := sb.Args.Add(segment) diff --git a/pkg/storage/sqlstorage/balances.go b/pkg/storage/sqlstorage/balances.go index fb0848b93..d480abb59 100644 --- a/pkg/storage/sqlstorage/balances.go +++ b/pkg/storage/sqlstorage/balances.go @@ -25,7 +25,11 @@ func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.BalancesQuer switch s.Schema().Flavor() { case sqlbuilder.PostgreSQL: src := strings.Split(q.Filters.AddressRegexp, ":") - sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) + if q.Filters.AddressRegexp[len(q.Filters.AddressRegexp)-2:] != ".*" { + sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) + } else { + src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2] + } for i, segment := range src { if segment == ".*" || segment == "*" || segment == "" { @@ -35,6 +39,7 @@ func (s *Store) GetBalancesAggregated(ctx context.Context, q ledger.BalancesQuer operator := "==" if !accountNameRegex.MatchString(segment) { operator = "like_regex" + segment = strings.ReplaceAll(segment, "\\", "\\\\") } arg := sb.Args.Add(segment) @@ -115,14 +120,18 @@ func (s *Store) GetBalances(ctx context.Context, q ledger.BalancesQuery) (api.Cu switch s.Schema().Flavor() { case sqlbuilder.PostgreSQL: src := strings.Split(q.Filters.AddressRegexp, ":") - sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) + if q.Filters.AddressRegexp[len(q.Filters.AddressRegexp)-2:] != ".*" { + sb.Where(fmt.Sprintf("jsonb_array_length(account_json) = %d", len(src))) + } else { + src[len(src)-1] = src[len(src)-1][:len(src[len(src)-1])-2] + } for i, segment := range src { if segment == ".*" || segment == "*" || segment == "" { continue } - arg := sb.Args.Add(segment) + arg := sb.Args.Add(strings.ReplaceAll(segment, "\\", "\\\\")) sb.Where(fmt.Sprintf("account_json @@ ('$[%d] like_regex \"' || %s::text || '\"')::jsonpath", i, arg)) } case sqlbuilder.SQLite: