diff --git a/docs/sources/configuration/_index.md b/docs/sources/configuration/_index.md index 5f80a6f90a30a..2a9ab50c06de7 100644 --- a/docs/sources/configuration/_index.md +++ b/docs/sources/configuration/_index.md @@ -3721,6 +3721,13 @@ The `azure_storage_config` block configures the connection to Azure object stora # CLI flag: -.azure.use-managed-identity [use_managed_identity: | default = false] +# Use a Federated Token to authenticate to the Azure storage account. +# Enable if you want to use Azure Workload Identity. Expects AZURE_CLIENT_ID, +# AZURE_TENANT_ID and AZURE_FEDERATED_TOKEN_FILE envs to be present (set automatically +# when using Azure Workload Identity). +# CLI flag: -.azure.use-federated-token +[use_federated_token: | default = false] + # User assigned identity ID to authenticate to the Azure storage account. # CLI flag: -.azure.user-assigned-id [user_assigned_id: | default = ""] diff --git a/docs/sources/installation/helm/reference.md b/docs/sources/installation/helm/reference.md index 8ecc2eea41a0a..55a259111e5a1 100644 --- a/docs/sources/installation/helm/reference.md +++ b/docs/sources/installation/helm/reference.md @@ -1795,6 +1795,7 @@ null "accountKey": null, "accountName": null, "requestTimeout": null, + "useFederatedToken": false, "useManagedIdentity": false, "userAssignedId": null }, diff --git a/go.mod b/go.mod index e609ff36e1cd6..e7e3f40cad8cd 100644 --- a/go.mod +++ b/go.mod @@ -113,6 +113,7 @@ require ( ) require ( + github.com/Azure/go-autorest/autorest v0.11.28 github.com/fsnotify/fsnotify v1.6.0 github.com/heroku/x v0.0.50 github.com/prometheus/alertmanager v0.25.0 @@ -133,7 +134,6 @@ require ( github.com/Azure/azure-sdk-for-go v65.0.0+incompatible // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-autorest v14.2.0+incompatible // indirect - github.com/Azure/go-autorest/autorest v0.11.28 // indirect github.com/Azure/go-autorest/autorest/azure/cli v0.4.5 // indirect github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect github.com/Azure/go-autorest/autorest/to v0.4.0 // indirect diff --git a/pkg/storage/chunk/client/azure/blob_storage_client.go b/pkg/storage/chunk/client/azure/blob_storage_client.go index 1585de49dae5f..1ab87a631c379 100644 --- a/pkg/storage/chunk/client/azure/blob_storage_client.go +++ b/pkg/storage/chunk/client/azure/blob_storage_client.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "os" "strings" "sync" "time" @@ -16,6 +17,7 @@ import ( "github.com/Azure/azure-pipeline-go/pipeline" "github.com/Azure/azure-storage-blob-go/azblob" "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/grafana/dskit/flagext" "github.com/mattn/go-ieproxy" @@ -32,6 +34,7 @@ import ( const ( // Environment azureGlobal = "AzureGlobal" + azurePublicCloud = "AzurePublicCloud" azureChinaCloud = "AzureChinaCloud" azureGermanCloud = "AzureGermanCloud" azureUSGovernment = "AzureUSGovernment" @@ -48,6 +51,11 @@ var ( azureUSGovernment: "blob.core.usgovcloudapi.net", } + defaultAuthFunctions = authFunctions{ + NewOAuthConfigFunc: adal.NewOAuthConfig, + NewServicePrincipalTokenFromFederatedTokenFunc: adal.NewServicePrincipalTokenFromFederatedToken, + } + // default Azure http client. defaultClientFactory = func() *http.Client { return &http.Client{ @@ -79,6 +87,7 @@ type BlobStorageConfig struct { ContainerName string `yaml:"container_name"` Endpoint string `yaml:"endpoint_suffix"` UseManagedIdentity bool `yaml:"use_managed_identity"` + UseFederatedToken bool `yaml:"use_federated_token"` UserAssignedID string `yaml:"user_assigned_id"` UseServicePrincipal bool `yaml:"use_service_principal"` ClientID string `yaml:"client_id"` @@ -94,6 +103,11 @@ type BlobStorageConfig struct { MaxRetryDelay time.Duration `yaml:"max_retry_delay"` } +type authFunctions struct { + NewOAuthConfigFunc func(activeDirectoryEndpoint, tenantID string) (*adal.OAuthConfig, error) + NewServicePrincipalTokenFromFederatedTokenFunc func(oauthConfig adal.OAuthConfig, clientID string, jwt string, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) +} + // RegisterFlags adds the flags required to config this to the given FlagSet func (c *BlobStorageConfig) RegisterFlags(f *flag.FlagSet) { c.RegisterFlagsWithPrefix("", f) @@ -107,6 +121,7 @@ func (c *BlobStorageConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagS f.StringVar(&c.ContainerName, prefix+"azure.container-name", "loki", "Name of the storage account blob container used to store chunks. This container must be created before running cortex.") f.StringVar(&c.Endpoint, prefix+"azure.endpoint-suffix", "", "Azure storage endpoint suffix without schema. The storage account name will be prefixed to this value to create the FQDN.") f.BoolVar(&c.UseManagedIdentity, prefix+"azure.use-managed-identity", false, "Use Managed Identity to authenticate to the Azure storage account.") + f.BoolVar(&c.UseFederatedToken, prefix+"azure.use-federated-token", false, "Use Federated Token to authenticate to the Azure storage account.") f.StringVar(&c.UserAssignedID, prefix+"azure.user-assigned-id", "", "User assigned identity ID to authenticate to the Azure storage account.") f.StringVar(&c.ChunkDelimiter, prefix+"azure.chunk-delimiter", "-", "Chunk delimiter for blob ID to be used") f.DurationVar(&c.RequestTimeout, prefix+"azure.request-timeout", 30*time.Second, "Timeout for requests made against azure blob storage.") @@ -316,7 +331,7 @@ func (b *BlobStorage) newPipeline(hedgingCfg hedging.Config, hedging bool) (pipe }) } - if !b.cfg.UseManagedIdentity && !b.cfg.UseServicePrincipal && b.cfg.UserAssignedID == "" { + if !b.cfg.UseFederatedToken && !b.cfg.UseManagedIdentity && !b.cfg.UseServicePrincipal && b.cfg.UserAssignedID == "" { credential, err := azblob.NewSharedKeyCredential(b.cfg.StorageAccountName, b.cfg.StorageAccountKey.String()) if err != nil { return nil, err @@ -341,7 +356,7 @@ func (b *BlobStorage) getOAuthToken() (azblob.TokenCredential, error) { if b.tc != nil { return b.tc, nil } - spt, err := b.getServicePrincipalToken() + spt, err := b.getServicePrincipalToken(defaultAuthFunctions) if err != nil { return nil, err } @@ -368,7 +383,7 @@ func (b *BlobStorage) getOAuthToken() (azblob.TokenCredential, error) { return b.tc, nil } -func (b *BlobStorage) getServicePrincipalToken() (*adal.ServicePrincipalToken, error) { +func (b *BlobStorage) getServicePrincipalToken(authFunctions authFunctions) (*adal.ServicePrincipalToken, error) { var endpoint string if b.cfg.Endpoint != "" { endpoint = b.cfg.Endpoint @@ -378,6 +393,28 @@ func (b *BlobStorage) getServicePrincipalToken() (*adal.ServicePrincipalToken, e resource := fmt.Sprintf("https://%s.%s", b.cfg.StorageAccountName, endpoint) + if b.cfg.UseFederatedToken { + token, err := b.servicePrincipalTokenFromFederatedToken(resource, authFunctions.NewOAuthConfigFunc, authFunctions.NewServicePrincipalTokenFromFederatedTokenFunc) + var customRefreshFunc adal.TokenRefresh = func(context context.Context, resource string) (*adal.Token, error) { + newToken, err := b.servicePrincipalTokenFromFederatedToken(resource, authFunctions.NewOAuthConfigFunc, authFunctions.NewServicePrincipalTokenFromFederatedTokenFunc) + if err != nil { + return nil, err + } + + err = newToken.Refresh() + if err != nil { + return nil, err + } + + token := newToken.Token() + + return &token, nil + } + + token.SetCustomRefreshFunc(customRefreshFunc) + return token, err + } + if b.cfg.UseServicePrincipal { config := auth.NewClientCredentialsConfig(b.cfg.ClientID, b.cfg.ClientSecret.String(), b.cfg.TenantID) config.Resource = resource @@ -395,6 +432,35 @@ func (b *BlobStorage) getServicePrincipalToken() (*adal.ServicePrincipalToken, e return msiConfig.ServicePrincipalToken() } +func (b *BlobStorage) servicePrincipalTokenFromFederatedToken(resource string, newOAuthConfigFunc func(activeDirectoryEndpoint, tenantID string) (*adal.OAuthConfig, error), newServicePrincipalTokenFromFederatedTokenFunc func(oauthConfig adal.OAuthConfig, clientID string, jwt string, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error)) (*adal.ServicePrincipalToken, error) { + environmentName := azurePublicCloud + if b.cfg.Environment != azureGlobal { + environmentName = b.cfg.Environment + } + + env, err := azure.EnvironmentFromName(environmentName) + if err != nil { + return nil, err + } + + azClientID := os.Getenv("AZURE_CLIENT_ID") + azTenantID := os.Getenv("AZURE_TENANT_ID") + + jwtBytes, err := os.ReadFile(os.Getenv("AZURE_FEDERATED_TOKEN_FILE")) + if err != nil { + return nil, err + } + + jwt := string(jwtBytes) + + oauthConfig, err := newOAuthConfigFunc(env.ActiveDirectoryEndpoint, azTenantID) + if err != nil { + return nil, err + } + + return newServicePrincipalTokenFromFederatedTokenFunc(*oauthConfig, azClientID, jwt, resource) +} + // List implements chunk.ObjectClient. func (b *BlobStorage) List(ctx context.Context, prefix, delimiter string) ([]client.StorageObject, []client.StorageCommonPrefix, error) { var storageObjects []client.StorageObject diff --git a/pkg/storage/chunk/client/azure/blob_storage_client_test.go b/pkg/storage/chunk/client/azure/blob_storage_client_test.go index 79cd7c6867822..f73cc399953e0 100644 --- a/pkg/storage/chunk/client/azure/blob_storage_client_test.go +++ b/pkg/storage/chunk/client/azure/blob_storage_client_test.go @@ -5,12 +5,16 @@ import ( "context" "net/http" "net/url" + "os" "strings" "testing" "time" + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/go-autorest/autorest/azure" "github.com/grafana/dskit/flagext" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "go.uber.org/atomic" "github.com/grafana/loki/pkg/storage/chunk/client/hedging" @@ -24,6 +28,58 @@ func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) return fn(req) } +type FederatedTokenTestSuite struct { + suite.Suite + config *BlobStorage + mockOAuthConfig *adal.OAuthConfig + mockedServicePrincipalToken *adal.ServicePrincipalToken +} + +func (suite *FederatedTokenTestSuite) SetupTest() { + suite.mockOAuthConfig, _ = adal.NewOAuthConfig("foo", "bar") + suite.mockedServicePrincipalToken = new(adal.ServicePrincipalToken) + suite.config = &BlobStorage{ + cfg: &BlobStorageConfig{ + ContainerName: "foo", + StorageAccountName: "bar", + Environment: azureGlobal, + UseFederatedToken: true, + }, + } + + suite.T().Setenv("AZURE_CLIENT_ID", "myClientId") + suite.T().Setenv("AZURE_TENANT_ID", "myTenantId") + + tmpDir := suite.T().TempDir() + _ = os.WriteFile(tmpDir+"/jwtToken", []byte("myJwtToken"), 0666) + suite.T().Setenv("AZURE_FEDERATED_TOKEN_FILE", tmpDir+"/jwtToken") +} + +func (suite *FederatedTokenTestSuite) TestGetServicePrincipalToken() { + newOAuthConfigFunc := func(activeDirectoryEndpoint, tenantID string) (*adal.OAuthConfig, error) { + require.Equal(suite.T(), azure.PublicCloud.ActiveDirectoryEndpoint, activeDirectoryEndpoint) + require.Equal(suite.T(), "myTenantId", tenantID) + + _, err := adal.NewOAuthConfig(activeDirectoryEndpoint, tenantID) + require.NoError(suite.T(), err) + + return suite.mockOAuthConfig, nil + } + + servicePrincipalTokenFromFederatedTokenFunc := func(oauthConfig adal.OAuthConfig, clientID string, jwt string, resource string, callbacks ...adal.TokenRefreshCallback) (*adal.ServicePrincipalToken, error) { + require.True(suite.T(), *suite.mockOAuthConfig == oauthConfig, "should return the mocked object") + require.Equal(suite.T(), "myClientId", clientID) + require.Equal(suite.T(), "myJwtToken", jwt) + require.Equal(suite.T(), "https://bar.blob.core.windows.net", resource) + return suite.mockedServicePrincipalToken, nil + } + + token, err := suite.config.getServicePrincipalToken(authFunctions{newOAuthConfigFunc, servicePrincipalTokenFromFederatedTokenFunc}) + + require.NoError(suite.T(), err) + require.True(suite.T(), suite.mockedServicePrincipalToken == token, "should return the mocked object") +} + func Test_Hedging(t *testing.T) { for _, tc := range []struct { name string @@ -131,6 +187,10 @@ func Test_DefaultBlobURL(t *testing.T) { require.Equal(t, *expect, bloburl.URL()) } +func Test_UseFederatedToken(t *testing.T) { + suite.Run(t, new(FederatedTokenTestSuite)) +} + func Test_EndpointSuffixWithBlob(t *testing.T) { c, err := NewBlobStorage(&BlobStorageConfig{ ContainerName: "foo", diff --git a/production/helm/loki/templates/_helpers.tpl b/production/helm/loki/templates/_helpers.tpl index 1c84850f0db9b..a093c58deee1f 100644 --- a/production/helm/loki/templates/_helpers.tpl +++ b/production/helm/loki/templates/_helpers.tpl @@ -215,6 +215,7 @@ azure: {{- end }} container_name: {{ $.Values.loki.storage.bucketNames.chunks }} use_managed_identity: {{ .useManagedIdentity }} + use_federated_token: {{ .useFederatedToken }} {{- with .userAssignedId }} user_assigned_id: {{ . }} {{- end }} @@ -281,6 +282,7 @@ azure: {{- end }} container_name: {{ $.Values.loki.storage.bucketNames.ruler }} use_managed_identity: {{ .useManagedIdentity }} + use_federated_token: {{ .useFederatedToken }} {{- with .userAssignedId }} user_assigned_id: {{ . }} {{- end }} diff --git a/production/helm/loki/values.yaml b/production/helm/loki/values.yaml index 0a45e530fbddb..f578a6429ba1a 100644 --- a/production/helm/loki/values.yaml +++ b/production/helm/loki/values.yaml @@ -229,6 +229,7 @@ loki: accountName: null accountKey: null useManagedIdentity: false + useFederatedToken: false userAssignedId: null requestTimeout: null filesystem: diff --git a/vendor/github.com/stretchr/testify/suite/doc.go b/vendor/github.com/stretchr/testify/suite/doc.go new file mode 100644 index 0000000000000..f91a245d3f8b4 --- /dev/null +++ b/vendor/github.com/stretchr/testify/suite/doc.go @@ -0,0 +1,65 @@ +// Package suite contains logic for creating testing suite structs +// and running the methods on those structs as tests. The most useful +// piece of this package is that you can create setup/teardown methods +// on your testing suites, which will run before/after the whole suite +// or individual tests (depending on which interface(s) you +// implement). +// +// A testing suite is usually built by first extending the built-in +// suite functionality from suite.Suite in testify. Alternatively, +// you could reproduce that logic on your own if you wanted (you +// just need to implement the TestingSuite interface from +// suite/interfaces.go). +// +// After that, you can implement any of the interfaces in +// suite/interfaces.go to add setup/teardown functionality to your +// suite, and add any methods that start with "Test" to add tests. +// Methods that do not match any suite interfaces and do not begin +// with "Test" will not be run by testify, and can safely be used as +// helper methods. +// +// Once you've built your testing suite, you need to run the suite +// (using suite.Run from testify) inside any function that matches the +// identity that "go test" is already looking for (i.e. +// func(*testing.T)). +// +// Regular expression to select test suites specified command-line +// argument "-run". Regular expression to select the methods +// of test suites specified command-line argument "-m". +// Suite object has assertion methods. +// +// A crude example: +// // Basic imports +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/suite" +// ) +// +// // Define the suite, and absorb the built-in basic suite +// // functionality from testify - including a T() method which +// // returns the current testing context +// type ExampleTestSuite struct { +// suite.Suite +// VariableThatShouldStartAtFive int +// } +// +// // Make sure that VariableThatShouldStartAtFive is set to five +// // before each test +// func (suite *ExampleTestSuite) SetupTest() { +// suite.VariableThatShouldStartAtFive = 5 +// } +// +// // All methods that begin with "Test" are run as tests within a +// // suite. +// func (suite *ExampleTestSuite) TestExample() { +// assert.Equal(suite.T(), 5, suite.VariableThatShouldStartAtFive) +// suite.Equal(5, suite.VariableThatShouldStartAtFive) +// } +// +// // In order for 'go test' to run this suite, we need to create +// // a normal test function and pass our suite to suite.Run +// func TestExampleTestSuite(t *testing.T) { +// suite.Run(t, new(ExampleTestSuite)) +// } +package suite diff --git a/vendor/github.com/stretchr/testify/suite/interfaces.go b/vendor/github.com/stretchr/testify/suite/interfaces.go new file mode 100644 index 0000000000000..8b98a8af275f7 --- /dev/null +++ b/vendor/github.com/stretchr/testify/suite/interfaces.go @@ -0,0 +1,53 @@ +package suite + +import "testing" + +// TestingSuite can store and return the current *testing.T context +// generated by 'go test'. +type TestingSuite interface { + T() *testing.T + SetT(*testing.T) +} + +// SetupAllSuite has a SetupSuite method, which will run before the +// tests in the suite are run. +type SetupAllSuite interface { + SetupSuite() +} + +// SetupTestSuite has a SetupTest method, which will run before each +// test in the suite. +type SetupTestSuite interface { + SetupTest() +} + +// TearDownAllSuite has a TearDownSuite method, which will run after +// all the tests in the suite have been run. +type TearDownAllSuite interface { + TearDownSuite() +} + +// TearDownTestSuite has a TearDownTest method, which will run after +// each test in the suite. +type TearDownTestSuite interface { + TearDownTest() +} + +// BeforeTest has a function to be executed right before the test +// starts and receives the suite and test names as input +type BeforeTest interface { + BeforeTest(suiteName, testName string) +} + +// AfterTest has a function to be executed right after the test +// finishes and receives the suite and test names as input +type AfterTest interface { + AfterTest(suiteName, testName string) +} + +// WithStats implements HandleStats, a function that will be executed +// when a test suite is finished. The stats contain information about +// the execution of that suite and its tests. +type WithStats interface { + HandleStats(suiteName string, stats *SuiteInformation) +} diff --git a/vendor/github.com/stretchr/testify/suite/stats.go b/vendor/github.com/stretchr/testify/suite/stats.go new file mode 100644 index 0000000000000..261da37f78fbc --- /dev/null +++ b/vendor/github.com/stretchr/testify/suite/stats.go @@ -0,0 +1,46 @@ +package suite + +import "time" + +// SuiteInformation stats stores stats for the whole suite execution. +type SuiteInformation struct { + Start, End time.Time + TestStats map[string]*TestInformation +} + +// TestInformation stores information about the execution of each test. +type TestInformation struct { + TestName string + Start, End time.Time + Passed bool +} + +func newSuiteInformation() *SuiteInformation { + testStats := make(map[string]*TestInformation) + + return &SuiteInformation{ + TestStats: testStats, + } +} + +func (s SuiteInformation) start(testName string) { + s.TestStats[testName] = &TestInformation{ + TestName: testName, + Start: time.Now(), + } +} + +func (s SuiteInformation) end(testName string, passed bool) { + s.TestStats[testName].End = time.Now() + s.TestStats[testName].Passed = passed +} + +func (s SuiteInformation) Passed() bool { + for _, stats := range s.TestStats { + if !stats.Passed { + return false + } + } + + return true +} diff --git a/vendor/github.com/stretchr/testify/suite/suite.go b/vendor/github.com/stretchr/testify/suite/suite.go new file mode 100644 index 0000000000000..895591878bf7f --- /dev/null +++ b/vendor/github.com/stretchr/testify/suite/suite.go @@ -0,0 +1,226 @@ +package suite + +import ( + "flag" + "fmt" + "os" + "reflect" + "regexp" + "runtime/debug" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var allTestsFilter = func(_, _ string) (bool, error) { return true, nil } +var matchMethod = flag.String("testify.m", "", "regular expression to select tests of the testify suite to run") + +// Suite is a basic testing suite with methods for storing and +// retrieving the current *testing.T context. +type Suite struct { + *assert.Assertions + mu sync.RWMutex + require *require.Assertions + t *testing.T +} + +// T retrieves the current *testing.T context. +func (suite *Suite) T() *testing.T { + suite.mu.RLock() + defer suite.mu.RUnlock() + return suite.t +} + +// SetT sets the current *testing.T context. +func (suite *Suite) SetT(t *testing.T) { + suite.mu.Lock() + defer suite.mu.Unlock() + suite.t = t + suite.Assertions = assert.New(t) + suite.require = require.New(t) +} + +// Require returns a require context for suite. +func (suite *Suite) Require() *require.Assertions { + suite.mu.Lock() + defer suite.mu.Unlock() + if suite.require == nil { + suite.require = require.New(suite.T()) + } + return suite.require +} + +// Assert returns an assert context for suite. Normally, you can call +// `suite.NoError(expected, actual)`, but for situations where the embedded +// methods are overridden (for example, you might want to override +// assert.Assertions with require.Assertions), this method is provided so you +// can call `suite.Assert().NoError()`. +func (suite *Suite) Assert() *assert.Assertions { + suite.mu.Lock() + defer suite.mu.Unlock() + if suite.Assertions == nil { + suite.Assertions = assert.New(suite.T()) + } + return suite.Assertions +} + +func recoverAndFailOnPanic(t *testing.T) { + r := recover() + failOnPanic(t, r) +} + +func failOnPanic(t *testing.T, r interface{}) { + if r != nil { + t.Errorf("test panicked: %v\n%s", r, debug.Stack()) + t.FailNow() + } +} + +// Run provides suite functionality around golang subtests. It should be +// called in place of t.Run(name, func(t *testing.T)) in test suite code. +// The passed-in func will be executed as a subtest with a fresh instance of t. +// Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName. +func (suite *Suite) Run(name string, subtest func()) bool { + oldT := suite.T() + defer suite.SetT(oldT) + return oldT.Run(name, func(t *testing.T) { + suite.SetT(t) + subtest() + }) +} + +// Run takes a testing suite and runs all of the tests attached +// to it. +func Run(t *testing.T, suite TestingSuite) { + defer recoverAndFailOnPanic(t) + + suite.SetT(t) + + var suiteSetupDone bool + + var stats *SuiteInformation + if _, ok := suite.(WithStats); ok { + stats = newSuiteInformation() + } + + tests := []testing.InternalTest{} + methodFinder := reflect.TypeOf(suite) + suiteName := methodFinder.Elem().Name() + + for i := 0; i < methodFinder.NumMethod(); i++ { + method := methodFinder.Method(i) + + ok, err := methodFilter(method.Name) + if err != nil { + fmt.Fprintf(os.Stderr, "testify: invalid regexp for -m: %s\n", err) + os.Exit(1) + } + + if !ok { + continue + } + + if !suiteSetupDone { + if stats != nil { + stats.Start = time.Now() + } + + if setupAllSuite, ok := suite.(SetupAllSuite); ok { + setupAllSuite.SetupSuite() + } + + suiteSetupDone = true + } + + test := testing.InternalTest{ + Name: method.Name, + F: func(t *testing.T) { + parentT := suite.T() + suite.SetT(t) + defer recoverAndFailOnPanic(t) + defer func() { + r := recover() + + if stats != nil { + passed := !t.Failed() && r == nil + stats.end(method.Name, passed) + } + + if afterTestSuite, ok := suite.(AfterTest); ok { + afterTestSuite.AfterTest(suiteName, method.Name) + } + + if tearDownTestSuite, ok := suite.(TearDownTestSuite); ok { + tearDownTestSuite.TearDownTest() + } + + suite.SetT(parentT) + failOnPanic(t, r) + }() + + if setupTestSuite, ok := suite.(SetupTestSuite); ok { + setupTestSuite.SetupTest() + } + if beforeTestSuite, ok := suite.(BeforeTest); ok { + beforeTestSuite.BeforeTest(methodFinder.Elem().Name(), method.Name) + } + + if stats != nil { + stats.start(method.Name) + } + + method.Func.Call([]reflect.Value{reflect.ValueOf(suite)}) + }, + } + tests = append(tests, test) + } + if suiteSetupDone { + defer func() { + if tearDownAllSuite, ok := suite.(TearDownAllSuite); ok { + tearDownAllSuite.TearDownSuite() + } + + if suiteWithStats, measureStats := suite.(WithStats); measureStats { + stats.End = time.Now() + suiteWithStats.HandleStats(suiteName, stats) + } + }() + } + + runTests(t, tests) +} + +// Filtering method according to set regular expression +// specified command-line argument -m +func methodFilter(name string) (bool, error) { + if ok, _ := regexp.MatchString("^Test", name); !ok { + return false, nil + } + return regexp.MatchString(*matchMethod, name) +} + +func runTests(t testing.TB, tests []testing.InternalTest) { + if len(tests) == 0 { + t.Log("warning: no tests to run") + return + } + + r, ok := t.(runner) + if !ok { // backwards compatibility with Go 1.6 and below + if !testing.RunTests(allTestsFilter, tests) { + t.Fail() + } + return + } + + for _, test := range tests { + r.Run(test.Name, test.F) + } +} + +type runner interface { + Run(name string, f func(t *testing.T)) bool +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 07eae41fecb26..da5965dfe1713 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -1151,6 +1151,7 @@ github.com/stretchr/objx github.com/stretchr/testify/assert github.com/stretchr/testify/mock github.com/stretchr/testify/require +github.com/stretchr/testify/suite # github.com/thanos-io/objstore v0.0.0-20220715165016-ce338803bc1e ## explicit; go 1.17 github.com/thanos-io/objstore