Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Clean collection utils #3028

Merged
merged 3 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/acceptance/helpers/parameter_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (c *ParameterClient) UnsetAccountParameter(t *testing.T, parameter sdk.Acco

func FindParameter[T ~string](t *testing.T, parameters []*sdk.Parameter, parameter T) *sdk.Parameter {
t.Helper()
param, err := collections.FindOne(parameters, func(p *sdk.Parameter) bool { return p.Key == string(parameter) })
param, err := collections.FindFirst(parameters, func(p *sdk.Parameter) bool { return p.Key == string(parameter) })
require.NoError(t, err)
return *param
}
3 changes: 1 addition & 2 deletions pkg/internal/collections/collection_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ import (

var ErrObjectNotFound = errors.New("object does not exist")

// TODO [SNOW-1473414]: move collection helpers fully with a separate PR
func FindOne[T any](collection []T, condition func(T) bool) (*T, error) {
func FindFirst[T any](collection []T, condition func(T) bool) (*T, error) {
for _, o := range collection {
if condition(o) {
return &o, nil
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
package collections

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestMap(t *testing.T) {
func Test_FindFirst(t *testing.T) {
stringSlice := []string{"1", "22", "333", "334"}

t.Run("basic find", func(t *testing.T) {
result, resultErr := FindFirst(stringSlice, func(s string) bool { return s == "22" })

require.Equal(t, "22", *result)
require.Nil(t, resultErr)
})

t.Run("two matching, first returned", func(t *testing.T) {
result, resultErr := FindFirst(stringSlice, func(s string) bool { return strings.HasPrefix(s, "33") })

require.Equal(t, "333", *result)
require.Nil(t, resultErr)
})

t.Run("no item", func(t *testing.T) {
result, resultErr := FindFirst(stringSlice, func(s string) bool { return s == "4444" })

require.Nil(t, result)
require.ErrorIs(t, resultErr, ErrObjectNotFound)
})
}

func Test_Map(t *testing.T) {
t.Run("basic mapping", func(t *testing.T) {
stringSlice := []string{"1", "22", "333"}
stringLenSlice := Map(stringSlice, func(s string) int { return len(s) })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestQueue(t *testing.T) {
func Test_Queue(t *testing.T) {
t.Run("empty queue initialization", func(t *testing.T) {
queue := NewQueue[int]()

Expand Down
20 changes: 10 additions & 10 deletions pkg/resources/api_authentication_integration_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func handleApiAuthImport(d *schema.ResourceData, integration *sdk.SecurityIntegr
return err
}

oauthAccessTokenValidity, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAccessTokenValidity, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_ACCESS_TOKEN_VALIDITY"
})
if err == nil {
Expand All @@ -235,7 +235,7 @@ func handleApiAuthImport(d *schema.ResourceData, integration *sdk.SecurityIntegr
return err
}
}
oauthRefreshTokenValidity, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthRefreshTokenValidity, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_REFRESH_TOKEN_VALIDITY"
})
if err == nil {
Expand All @@ -247,21 +247,21 @@ func handleApiAuthImport(d *schema.ResourceData, integration *sdk.SecurityIntegr
return err
}
}
oauthClientId, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_CLIENT_ID" })
oauthClientId, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_CLIENT_ID" })
if err == nil {
if err = d.Set("oauth_client_id", oauthClientId.Value); err != nil {
return err
}
}
oauthClientAuthMethod, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthClientAuthMethod, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_CLIENT_AUTH_METHOD"
})
if err == nil {
if err = d.Set("oauth_client_auth_method", oauthClientAuthMethod.Value); err != nil {
return err
}
}
oauthTokenEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_TOKEN_ENDPOINT" })
oauthTokenEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_TOKEN_ENDPOINT" })
if err == nil {
if err = d.Set("oauth_token_endpoint", oauthTokenEndpoint.Value); err != nil {
return err
Expand All @@ -288,33 +288,33 @@ func handleApiAuthRead(d *schema.ResourceData,
return err
}
if withExternalChangesMarking {
oauthAccessTokenValidity, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAccessTokenValidity, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_ACCESS_TOKEN_VALIDITY"
})
if err != nil {
return err
}

oauthRefreshTokenValidity, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthRefreshTokenValidity, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_REFRESH_TOKEN_VALIDITY"
})
if err != nil {
return err
}

oauthClientId, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_CLIENT_ID" })
oauthClientId, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_CLIENT_ID" })
if err != nil {
return err
}

oauthClientAuthMethod, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthClientAuthMethod, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_CLIENT_AUTH_METHOD"
})
if err != nil {
return err
}

oauthTokenEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_TOKEN_ENDPOINT" })
oauthTokenEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_TOKEN_ENDPOINT" })
if err != nil {
return err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ func ImportApiAuthenticationWithAuthorizationCodeGrant(ctx context.Context, d *s
if err := handleApiAuthImport(d, integration, properties); err != nil {
return nil, err
}
oauthAuthorizationEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAuthorizationEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_AUTHORIZATION_ENDPOINT"
})
if err == nil {
if err = d.Set("oauth_authorization_endpoint", oauthAuthorizationEndpoint.Value); err != nil {
return nil, err
}
}
oauthAllowedScopes, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
oauthAllowedScopes, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
if err == nil {
if err = d.Set("oauth_allowed_scopes", sdk.ParseCommaSeparatedStringArray(oauthAllowedScopes.Value, false)); err != nil {
return nil, err
Expand Down Expand Up @@ -168,14 +168,14 @@ func ReadContextApiAuthenticationIntegrationWithAuthorizationCodeGrant(withExter
if c := integration.Category; c != sdk.SecurityIntegrationCategory {
return diag.FromErr(fmt.Errorf("expected %v to be a %s integration, got %v", id, sdk.SecurityIntegrationCategory, c))
}
oauthAuthorizationEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAuthorizationEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_AUTHORIZATION_ENDPOINT"
})
if err != nil {
return diag.FromErr(err)
}

oauthAllowedScopes, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
oauthAllowedScopes, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
if err != nil {
return diag.FromErr(err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func ImportApiAuthenticationWithClientCredentials(ctx context.Context, d *schema
if err := handleApiAuthImport(d, integration, properties); err != nil {
return nil, err
}
oauthAllowedScopes, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
oauthAllowedScopes, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
if err == nil {
if err = d.Set("oauth_allowed_scopes", sdk.ParseCommaSeparatedStringArray(oauthAllowedScopes.Value, false)); err != nil {
return nil, err
Expand Down Expand Up @@ -148,7 +148,7 @@ func ReadContextApiAuthenticationIntegrationWithClientCredentials(withExternalCh
if c := integration.Category; c != sdk.SecurityIntegrationCategory {
return diag.FromErr(fmt.Errorf("expected %v to be a %s integration, got %v", id, sdk.SecurityIntegrationCategory, c))
}
oauthAllowedScopes, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
oauthAllowedScopes, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ALLOWED_SCOPES" })
if err != nil {
return diag.FromErr(err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,15 @@ func ImportApiAuthenticationWithJwtBearer(ctx context.Context, d *schema.Resourc
if err := handleApiAuthImport(d, integration, properties); err != nil {
return nil, err
}
oauthAuthorizationEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAuthorizationEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_AUTHORIZATION_ENDPOINT"
})
if err == nil {
if err = d.Set("oauth_authorization_endpoint", oauthAuthorizationEndpoint.Value); err != nil {
return nil, err
}
}
oauthAssertionIssuer, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ASSERTION_ISSUER" })
oauthAssertionIssuer, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ASSERTION_ISSUER" })
if err == nil {
if err = d.Set("oauth_assertion_issuer", oauthAssertionIssuer.Value); err != nil {
return nil, err
Expand Down Expand Up @@ -156,14 +156,14 @@ func ReadContextApiAuthenticationIntegrationWithJwtBearer(withExternalChangesMar
if c := integration.Category; c != sdk.SecurityIntegrationCategory {
return diag.FromErr(fmt.Errorf("expected %v to be a %s integration, got %v", id, sdk.SecurityIntegrationCategory, c))
}
oauthAuthorizationEndpoint, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool {
oauthAuthorizationEndpoint, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool {
return property.Name == "OAUTH_AUTHORIZATION_ENDPOINT"
})
if err != nil {
return diag.FromErr(err)
}

oauthAssertionIssuer, err := collections.FindOne(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ASSERTION_ISSUER" })
oauthAssertionIssuer, err := collections.FindFirst(properties, func(property sdk.SecurityIntegrationProperty) bool { return property.Name == "OAUTH_ASSERTION_ISSUER" })
if err != nil {
return diag.FromErr(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/resources/custom_diffs.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func BoolParameterValueComputedIf[T ~string](key string, params []*sdk.Parameter

func ParameterValueComputedIf[T ~string](key string, parameters []*sdk.Parameter, objectParameterLevel sdk.ParameterType, param T, valueToString func(v any) string) schema.CustomizeDiffFunc {
return func(ctx context.Context, d *schema.ResourceDiff, meta any) error {
foundParameter, err := collections.FindOne(parameters, func(parameter *sdk.Parameter) bool { return parameter.Key == string(param) })
foundParameter, err := collections.FindFirst(parameters, func(parameter *sdk.Parameter) bool { return parameter.Key == string(param) })
if err != nil {
log.Printf("[WARN] failed to find parameter: %s", param)
return nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/resources/database_state_upgraders.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func v092DatabaseStateUpgrader(ctx context.Context, rawState map[string]any, met
}

for i, accountLocator := range accountLocators {
replicationAccount, err := collections.FindOne(replicationAccounts, func(account *sdk.ReplicationAccount) bool {
replicationAccount, err := collections.FindFirst(replicationAccounts, func(account *sdk.ReplicationAccount) bool {
return account.AccountLocator == accountLocator
})
if err != nil {
Expand Down
Loading
Loading