Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: split out main.go functionality #188

Merged
merged 5 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@ docker-image:
test:
go test ./...

test-coverage:
TEST_DB=1 TEST_MQ=1 go test ./... -coverprofile cover.out && go tool cover -html=cover.out

test-integration:
TEST_DB=1 TEST_MQ=1 go test ./...

upgrade:
go get -u ./...
go get -u ./...
80 changes: 80 additions & 0 deletions cmd/s3scanner/args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package s3scanner

import (
"errors"
"fmt"
"github.com/spf13/viper"
)

type ArgCollection struct {
BucketFile string
BucketName string
DoEnumerate bool
Json bool
ProviderFlag string
Threads int
UseMq bool
Verbose bool
Version bool
WriteToDB bool
}

func (args ArgCollection) Validate() error {
// Validate: only 1 input flag is provided
numInputFlags := 0
if args.UseMq {
numInputFlags += 1
}
if args.BucketName != "" {
numInputFlags += 1
}
if args.BucketFile != "" {
numInputFlags += 1
}
if numInputFlags != 1 {
return errors.New("exactly one of: -bucket, -bucket-file, -mq required")
}

return nil
}

/*
validateConfig checks that the config file contains all necessary keys according to the args specified
*/
func validateConfig(args ArgCollection) error {
expectedKeys := []string{}
configFileRequired := false
if args.ProviderFlag == "custom" {
configFileRequired = true
expectedKeys = append(expectedKeys, []string{"providers.custom.insecure", "providers.custom.endpoint_format", "providers.custom.regions", "providers.custom.address_style"}...)
}
if args.WriteToDB {
configFileRequired = true
expectedKeys = append(expectedKeys, []string{"db.uri"}...)
}
if args.UseMq {
configFileRequired = true
expectedKeys = append(expectedKeys, []string{"mq.queue_name", "mq.uri"}...)
}
// User didn't give any arguments that require the config file
if !configFileRequired {
return nil
}

// Try to find and read config file
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
return errors.New("config file not found")
} else {
return err
}
}

// Verify all expected keys are in the config file
for _, k := range expectedKeys {
if !viper.IsSet(k) {
return fmt.Errorf("config file missing key: %s", k)
}
}
return nil
}
87 changes: 87 additions & 0 deletions cmd/s3scanner/args_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package s3scanner

import (
"errors"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"testing"
)

func TestArgCollection_Validate(t *testing.T) {
goodInputs := []ArgCollection{
{
BucketName: "asdf",
BucketFile: "",
UseMq: false,
},
{
BucketName: "",
BucketFile: "buckets.txt",
UseMq: false,
},
{
BucketName: "",
BucketFile: "",
UseMq: true,
},
}
tooManyInputs := []ArgCollection{
{
BucketName: "asdf",
BucketFile: "asdf",
UseMq: false,
},
{
BucketName: "adsf",
BucketFile: "",
UseMq: true,
},
{
BucketName: "",
BucketFile: "asdf.txt",
UseMq: true,
},
}

for _, v := range goodInputs {
err := v.Validate()
if err != nil {
t.Errorf("%v: %e", v, err)
}
}
for _, v := range tooManyInputs {
err := v.Validate()
if err == nil {
t.Errorf("expected error but did not find one: %v", v)
}
}
}

func TestValidateConfig(t *testing.T) {
a := ArgCollection{
DoEnumerate: false,
Json: false,
ProviderFlag: "custom",
UseMq: true,
WriteToDB: true,
}
viper.AddConfigPath("../../")
viper.SetConfigName("config") // name of config file (without extension)
viper.SetConfigType("yml") // REQUIRED if the config file does not have the extension in the name
err := validateConfig(a)
assert.Nil(t, err)
}

func TestValidateConfig_NotFound(t *testing.T) {
a := ArgCollection{
DoEnumerate: false,
Json: false,
ProviderFlag: "custom",
UseMq: true,
WriteToDB: true,
}
viper.SetConfigName("asdf") // won't be found
viper.SetConfigType("yml")
err := validateConfig(a)
assert.Equal(t, errors.New("config file not found"), err)
}
Loading