Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make region-client map concurrency safe #197

Merged
merged 1 commit into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions provider/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@ package provider
import (
"context"
"errors"
"net/http"

"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
log "github.com/sirupsen/logrus"
"net/http"
)

type providerAWS struct {
existsClient *s3.Client
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (a *providerAWS) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
Expand Down Expand Up @@ -85,7 +87,8 @@ func NewProviderAWS() (*providerAWS, error) {
if usErr != nil {
return nil, usErr
}
pa.clients = map[string]*s3.Client{"us-east-1": usEastClient}
pa.clients = clientmap.New()
pa.clients.Set("us-east-1", usEastClient)
return pa, nil
}

Expand Down Expand Up @@ -142,8 +145,8 @@ func (a *providerAWS) newClient(region string) (*s3.Client, error) {

// TODO: This method is copied from providerLinode
func (a *providerAWS) getRegionClient(region string) (*s3.Client, error) {
c, ok := a.clients[region]
if ok {
c := a.clients.Get(region)
if c != nil {
return c, nil
}

Expand All @@ -152,6 +155,6 @@ func (a *providerAWS) getRegionClient(region string) (*s3.Client, error) {
if err != nil {
return nil, err
}
a.clients[region] = c // TODO: Make sure this is thread-safe
a.clients.Set(region, c)
return c, nil
}
54 changes: 54 additions & 0 deletions provider/clientmap/clientmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package clientmap

import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"sync"
)

type ClientMap struct {
sync.Mutex
inner map[string]*s3.Client
}

func New() *ClientMap {
return &ClientMap{
Mutex: sync.Mutex{},
inner: make(map[string]*s3.Client),
}
}

func WithCapacity(cap int) *ClientMap {
return &ClientMap{
Mutex: sync.Mutex{},
inner: make(map[string]*s3.Client, cap),
}
}

func (m *ClientMap) Get(key string) *s3.Client {
m.Lock()
defer m.Unlock()
if v, ok := m.inner[key]; ok {
return v
}
return nil
}

func (m *ClientMap) Set(key string, value *s3.Client) {
m.Lock()
m.inner[key] = value
m.Unlock()
}

func (m *ClientMap) Len() int {
m.Lock()
defer m.Unlock()
return len(m.inner)
}

func (m *ClientMap) Each(fn func(region string, client *s3.Client)) {
m.Lock()
for region, client := range m.inner {
fn(region, client)
}
m.Unlock()
}
18 changes: 8 additions & 10 deletions provider/custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package provider
import (
"errors"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"strings"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type CustomProvider struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
insecure bool
addressStyle int
endpointFormat string
Expand Down Expand Up @@ -66,11 +68,7 @@ func (cp CustomProvider) Enumerate(b *bucket.Bucket) error {
}

func (cp *CustomProvider) getRegionClient(region string) *s3.Client {
c, ok := cp.clients[region]
if ok {
return c
}
return nil
return cp.clients.Get(region)
}

/*
Expand Down Expand Up @@ -98,15 +96,15 @@ func NewCustomProvider(addressStyle string, insecure bool, regions []string, end
return cp, nil
}

func (cp *CustomProvider) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(cp.regions))
func (cp *CustomProvider) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(cp.regions))
for _, r := range cp.regions {
regionUrl := strings.Replace(cp.endpointFormat, "$REGION", r, -1)
client, err := newNonAWSClient(cp, regionUrl)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
16 changes: 7 additions & 9 deletions provider/digitalocean.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type providerDO struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (pdo providerDO) Insecure() bool {
Expand Down Expand Up @@ -66,25 +68,21 @@ func (pdo *providerDO) Regions() []string {
return urls
}

func (pdo *providerDO) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(pdo.regions))
func (pdo *providerDO) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(pdo.regions))
for _, r := range pdo.Regions() {
client, err := newNonAWSClient(pdo, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
}

func (pdo *providerDO) getRegionClient(region string) *s3.Client {
c, ok := pdo.clients[region]
if ok {
return c
}
return nil
return pdo.clients.Get(region)
}

func NewProviderDO() (*providerDO, error) {
Expand Down
16 changes: 7 additions & 9 deletions provider/dreamhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type ProviderDreamhost struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (p ProviderDreamhost) Insecure() bool {
Expand Down Expand Up @@ -46,11 +48,7 @@ func (p ProviderDreamhost) Scan(bucket *bucket.Bucket, doDestructiveChecks bool)
}

func (p ProviderDreamhost) getRegionClient(region string) *s3.Client {
c, ok := p.clients[region]
if ok {
return c
}
return nil
return p.clients.Get(region)
}

func (p ProviderDreamhost) Enumerate(b *bucket.Bucket) error {
Expand All @@ -74,14 +72,14 @@ func (p ProviderDreamhost) Regions() []string {
return urls
}

func (p *ProviderDreamhost) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(p.regions))
func (p *ProviderDreamhost) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(p.regions))
for _, r := range p.Regions() {
client, err := newNonAWSClient(p, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
6 changes: 5 additions & 1 deletion provider/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package provider

import (
"errors"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

// GCP like AWS, has a "universal" endpoint, but unlike AWS GCP does not require you to follow a redirect to the
Expand All @@ -30,7 +32,9 @@ func (g GCP) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
if !bucket.IsValidS3BucketName(b.Name) {
return nil, errors.New("invalid bucket name")
}
exists, region, err := bucketExists(map[string]*s3.Client{"default": g.client}, b)
clients := clientmap.New()
clients.Set("default", g.client)
exists, region, err := bucketExists(clients, b)
if err != nil {
return b, err
}
Expand Down
16 changes: 7 additions & 9 deletions provider/linode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type providerLinode struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func NewProviderLinode() (*providerLinode, error) {
Expand All @@ -25,11 +27,7 @@ func NewProviderLinode() (*providerLinode, error) {
}

func (pl *providerLinode) getRegionClient(region string) *s3.Client {
c, ok := pl.clients[region]
if ok {
return c
}
return nil
return pl.clients.Get(region)
}

func (pl *providerLinode) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
Expand Down Expand Up @@ -61,14 +59,14 @@ func (pl *providerLinode) Enumerate(b *bucket.Bucket) error {
return nil
}

func (pl *providerLinode) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(pl.regions))
func (pl *providerLinode) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(pl.regions))
for _, r := range pl.Regions() {
client, err := newNonAWSClient(pl, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
16 changes: 9 additions & 7 deletions provider/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
Expand All @@ -13,9 +16,8 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/permission"
"github.com/sa7mon/s3scanner/provider/clientmap"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)

const (
Expand Down Expand Up @@ -195,13 +197,13 @@ func checkPermissions(client *s3.Client, b *bucket.Bucket, doDestructiveChecks b
return nil
}

func bucketExists(clients map[string]*s3.Client, b *bucket.Bucket) (bool, string, error) {
func bucketExists(clients *clientmap.ClientMap, b *bucket.Bucket) (bool, string, error) {
// TODO: Should this return a client or a region name? If region name, we'll need GetClient(region)
// TODO: Add region priority - order in which to check. maps are not ordered
results := make(chan bucketCheckResult, len(clients))
results := make(chan bucketCheckResult, clients.Len())
e := make(chan error, 1)

for region, client := range clients {
clients.Each(func(region string, client *s3.Client) {
go func(bucketName string, client *s3.Client, region string) {
logFields := log.Fields{
"bucket_name": b.Name,
Expand Down Expand Up @@ -234,9 +236,9 @@ func bucketExists(clients map[string]*s3.Client, b *bucket.Bucket) (bool, string
e <- err
}
}(b.Name, client, region)
}
})

for range clients {
for i := 0; i < clients.Len(); i++ {
select {
case err := <-e:
return false, "", err
Expand Down