Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ESD-25205: Fix concurrency issue in association resources #425

Merged
merged 4 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions internal/mutex/mutex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package mutex

import (
"log"
"sync"
)

// KeyValue is a simple key/value
// store for arbitrary mutexes.
type KeyValue struct {
lock sync.Mutex
store map[string]*sync.Mutex
}

// Lock the mutex for the given key.
func (m *KeyValue) Lock(key string) {
log.Printf("[DEBUG] Locking mutex for key: %q", key)
defer log.Printf("[DEBUG] Locked mutex for key: %q", key)

m.get(key).Lock()
}

// Unlock the mutex for the given key.
func (m *KeyValue) Unlock(key string) {
log.Printf("[DEBUG] Unlocking mutex for key: %q", key)
defer log.Printf("[DEBUG] Unlocked mutex for key: %q", key)

m.get(key).Unlock()
}

// Returns a mutex for the given key.
func (m *KeyValue) get(key string) *sync.Mutex {
m.lock.Lock()
defer m.lock.Unlock()

mutex, ok := m.store[key]
if !ok {
mutex = &sync.Mutex{}
m.store[key] = mutex
}

return mutex
}

// New returns a properly initialized KeyValue mutex.
func New() *KeyValue {
return &KeyValue{
store: make(map[string]*sync.Mutex),
}
}
61 changes: 61 additions & 0 deletions internal/mutex/mutex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package mutex

import (
"testing"
"time"
)

func TestKeyValueLock(t *testing.T) {
keyValueMutex := New()
keyValueMutex.Lock("foo")

doneChannel := make(chan struct{})
go func() {
keyValueMutex.Lock("foo")
close(doneChannel)
}()

select {
case <-doneChannel:
t.Fatal("Second lock was able to be taken. This shouldn't happen.")
case <-time.After(50 * time.Millisecond):
// Test passing.
}
}

func TestKeyValueUnlock(t *testing.T) {
keyValueMutex := New()
keyValueMutex.Lock("foo")
keyValueMutex.Unlock("foo")

doneChannel := make(chan struct{})
go func() {
keyValueMutex.Lock("foo")
close(doneChannel)
}()

select {
case <-doneChannel:
// Test passing.
case <-time.After(50 * time.Millisecond):
t.Fatal("Second lock blocked after unlock. This shouldn't happen.")
}
}

func TestKeyValueDifferentKeys(t *testing.T) {
keyValueMutex := New()
keyValueMutex.Lock("foo")

doneChannel := make(chan struct{})
go func() {
keyValueMutex.Lock("bar")
close(doneChannel)
}()

select {
case <-doneChannel:
// Test passing.
case <-time.After(50 * time.Millisecond):
t.Fatal("Second lock on a different key blocked. This shouldn't happen.")
}
}
4 changes: 4 additions & 0 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/meta"

"github.com/auth0/terraform-provider-auth0/internal/mutex"
)

var version = "dev"

var globalMutex = mutex.New()

// New returns a *schema.Provider.
func New() *schema.Provider {
provider := &schema.Provider{
Expand Down
16 changes: 8 additions & 8 deletions internal/provider/resource_auth0_connection_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ func createConnectionClient(ctx context.Context, data *schema.ResourceData, meta
api := meta.(*management.Management)

connectionID := data.Get("connection_id").(string)

globalMutex.Lock(connectionID)
defer globalMutex.Unlock(connectionID)

connection, err := api.Connection.Read(connectionID)
if err != nil {
if mErr, ok := err.(management.Error); ok && mErr.Status() == http.StatusNotFound {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

404 checks should not happen on CREATE to prevent Root resource was present, but now absent errors.

data.SetId("")
return nil
}
return diag.FromErr(err)
}

Expand All @@ -103,10 +103,6 @@ func createConnectionClient(ctx context.Context, data *schema.ResourceData, meta
connectionID,
&management.Connection{EnabledClients: &enabledClients},
); err != nil {
if mErr, ok := err.(management.Error); ok && mErr.Status() == http.StatusNotFound {
data.SetId("")
return nil
}
return diag.FromErr(err)
}

Expand Down Expand Up @@ -153,6 +149,10 @@ func deleteConnectionClient(_ context.Context, data *schema.ResourceData, meta i
api := meta.(*management.Management)

connectionID := data.Get("connection_id").(string)

globalMutex.Lock(connectionID)
defer globalMutex.Unlock(connectionID)

connection, err := api.Connection.Read(connectionID)
if err != nil {
if mErr, ok := err.(management.Error); ok && mErr.Status() == http.StatusNotFound {
Expand Down
16 changes: 16 additions & 0 deletions internal/provider/resource_auth0_organization_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ func createOrganizationConnection(ctx context.Context, data *schema.ResourceData
api := meta.(*management.Management)

organizationID := data.Get("organization_id").(string)

globalMutex.Lock(organizationID)
defer globalMutex.Unlock(organizationID)

connectionID := data.Get("connection_id").(string)
assignMembershipOnLogin := data.Get("assign_membership_on_login").(bool)

Expand All @@ -119,6 +123,10 @@ func readOrganizationConnection(ctx context.Context, data *schema.ResourceData,

organizationConnection, err := api.Organization.Connection(organizationID, connectionID)
if err != nil {
if err, ok := err.(management.Error); ok && err.Status() == http.StatusNotFound {
data.SetId("")
return nil
}
return diag.FromErr(err)
}

Expand All @@ -135,6 +143,10 @@ func updateOrganizationConnection(ctx context.Context, data *schema.ResourceData
api := meta.(*management.Management)

organizationID := data.Get("organization_id").(string)

globalMutex.Lock(organizationID)
defer globalMutex.Unlock(organizationID)

connectionID := data.Get("connection_id").(string)
assignMembershipOnLogin := data.Get("assign_membership_on_login").(bool)

Expand All @@ -153,6 +165,10 @@ func deleteOrganizationConnection(ctx context.Context, data *schema.ResourceData
api := meta.(*management.Management)

organizationID := data.Get("organization_id").(string)

globalMutex.Lock(organizationID)
defer globalMutex.Unlock(organizationID)

connectionID := data.Get("connection_id").(string)

if err := api.Organization.DeleteConnection(organizationID, connectionID); err != nil {
Expand Down
19 changes: 17 additions & 2 deletions internal/provider/resource_auth0_organization_member.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ func createOrganizationMember(ctx context.Context, d *schema.ResourceData, m int
userID := d.Get("user_id").(string)
orgID := d.Get("organization_id").(string)

globalMutex.Lock(orgID)
defer globalMutex.Unlock(orgID)

if err := api.Organization.AddMembers(orgID, []string{userID}); err != nil {
return diag.FromErr(err)
}
Expand All @@ -104,8 +107,8 @@ func assignRoles(d *schema.ResourceData, api *management.Management) error {
return nil
}

orgID := d.Get("organization_id").(string)
userID := d.Get("user_id").(string)
orgID := d.Get("organization_id").(string)

toAdd, toRemove := value.Difference(d, "roles")

Expand Down Expand Up @@ -159,6 +162,10 @@ func readOrganizationMember(ctx context.Context, d *schema.ResourceData, m inter

roles, err := api.Organization.MemberRoles(orgID, userID)
if err != nil {
if err, ok := err.(management.Error); ok && err.Status() == http.StatusNotFound {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 404 check was missed on this read func.

d.SetId("")
return nil
}
return diag.FromErr(err)
}

Expand All @@ -175,6 +182,11 @@ func readOrganizationMember(ctx context.Context, d *schema.ResourceData, m inter
func updateOrganizationMember(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
api := m.(*management.Management)

orgID := d.Get("organization_id").(string)

globalMutex.Lock(orgID)
defer globalMutex.Unlock(orgID)

if err := assignRoles(d, api); err != nil {
return diag.FromErr(fmt.Errorf("failed to assign members to organization. %w", err))
}
Expand All @@ -185,8 +197,11 @@ func updateOrganizationMember(ctx context.Context, d *schema.ResourceData, m int
func deleteOrganizationMember(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
api := m.(*management.Management)

orgID := d.Get("organization_id").(string)
userID := d.Get("user_id").(string)
orgID := d.Get("organization_id").(string)

globalMutex.Lock(orgID)
defer globalMutex.Unlock(orgID)

if err := api.Organization.DeleteMember(orgID, []string{userID}); err != nil {
if err, ok := err.(management.Error); ok && err.Status() == http.StatusNotFound {
Expand Down