diff --git a/pkg/clients/authenticator.go b/pkg/clients/authenticator.go new file mode 100644 index 00000000..77a77b5f --- /dev/null +++ b/pkg/clients/authenticator.go @@ -0,0 +1,8 @@ +package clients + +// Authenticator is an interface for handling authentication with a service +type Authenticator interface { + GetToken() string + RefreshToken() error + NeedsTokenRefresh() error +} diff --git a/pkg/clients/cloud_client.go b/pkg/clients/cloud_client.go index 47773d79..eda74170 100644 --- a/pkg/clients/cloud_client.go +++ b/pkg/clients/cloud_client.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "strconv" "strings" @@ -41,19 +40,19 @@ type CloudProviderResponse struct { // CloudAPIClient is a client for interacting with the Materialize Cloud API type CloudAPIClient struct { - HTTPClient *http.Client - FronteggClient *FronteggClient - Endpoint string - BaseEndpoint string + HTTPClient *http.Client + Authenticator Authenticator + Endpoint string + BaseEndpoint string } // NewCloudAPIClient creates a new Cloud API client -func NewCloudAPIClient(fronteggClient *FronteggClient, cloudAPIEndpoint, baseEndpoint string) *CloudAPIClient { +func NewCloudAPIClient(authenticator Authenticator, cloudAPIEndpoint, baseEndpoint string) *CloudAPIClient { return &CloudAPIClient{ - HTTPClient: &http.Client{}, - FronteggClient: fronteggClient, - Endpoint: cloudAPIEndpoint, - BaseEndpoint: baseEndpoint, + HTTPClient: &http.Client{}, + Authenticator: authenticator, + Endpoint: cloudAPIEndpoint, + BaseEndpoint: baseEndpoint, } } @@ -61,28 +60,17 @@ func NewCloudAPIClient(fronteggClient *FronteggClient, cloudAPIEndpoint, baseEnd func (c *CloudAPIClient) ListCloudProviders(ctx context.Context) ([]CloudProvider, error) { providersEndpoint := fmt.Sprintf("%s/api/cloud-regions", c.Endpoint) - // Reuse the FronteggClient's HTTPClient which already includes the Authorization token. - resp, err := c.FronteggClient.HTTPClient.Get(providersEndpoint) + resp, err := c.doRequest(ctx, http.MethodGet, providersEndpoint, nil) if err != nil { - return nil, fmt.Errorf("error listing cloud providers: %v", err) + return nil, fmt.Errorf("error listing cloud providers: %w", err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %v", err) - } - return nil, fmt.Errorf("cloud API returned non-200 status code: %d, body: %s", resp.StatusCode, string(body)) - } - var response CloudProviderResponse if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - return nil, err + return nil, fmt.Errorf("error decoding response: %w", err) } - log.Printf("[DEBUG] Cloud providers response body: %+v\n", response) - return response.Data, nil } @@ -90,29 +78,17 @@ func (c *CloudAPIClient) ListCloudProviders(ctx context.Context) ([]CloudProvide func (c *CloudAPIClient) GetRegionDetails(ctx context.Context, provider CloudProvider) (*CloudRegion, error) { regionEndpoint := fmt.Sprintf("%s/api/region", provider.Url) - resp, err := c.FronteggClient.HTTPClient.Get(regionEndpoint) + resp, err := c.doRequest(ctx, http.MethodGet, regionEndpoint, nil) if err != nil { - return nil, fmt.Errorf("error retrieving region details: %v", err) + return nil, fmt.Errorf("error retrieving region details: %w", err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %v", err) - } - return nil, fmt.Errorf("cloud API returned non-200 status code: %d, body: %s", resp.StatusCode, string(body)) - } - - log.Printf("[DEBUG] Region details response body: %+v\n", resp.Body) - var region CloudRegion if err := json.NewDecoder(resp.Body).Decode(®ion); err != nil { - return nil, err + return nil, fmt.Errorf("error decoding region details: %w", err) } - log.Printf("[DEBUG] Region details response body: %+v\n", region) - return ®ion, nil } @@ -120,30 +96,16 @@ func (c *CloudAPIClient) GetRegionDetails(ctx context.Context, provider CloudPro func (c *CloudAPIClient) EnableRegion(ctx context.Context, provider CloudProvider) (*CloudRegion, error) { endpoint := fmt.Sprintf("%s/api/region", provider.Url) emptyJSONPayload := bytes.NewBuffer([]byte("{}")) - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, emptyJSONPayload) - if err != nil { - return nil, fmt.Errorf("error creating request to enable region: %v", err) - } - req.Header.Add("Content-Type", "application/json") - - resp, err := c.FronteggClient.HTTPClient.Do(req) + resp, err := c.doRequest(ctx, http.MethodPatch, endpoint, emptyJSONPayload) if err != nil { - return nil, fmt.Errorf("error sending request to enable region: %v", err) + return nil, fmt.Errorf("error enabling region: %w", err) } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %v", err) - } - return nil, fmt.Errorf("cloud API returned non-200/201 status code: %d, body: %s", resp.StatusCode, string(body)) - } - var region CloudRegion if err := json.NewDecoder(resp.Body).Decode(®ion); err != nil { - return nil, err + return nil, fmt.Errorf("error decoding enabled region details: %w", err) } return ®ion, nil @@ -198,3 +160,44 @@ func SplitHostPort(hostPortStr string) (host string, port int, err error) { return "", 0, fmt.Errorf("invalid host:port format") } } + +func (c *CloudAPIClient) doRequest(ctx context.Context, method, url string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, fmt.Errorf("error creating request: %w", err) + } + + if err := c.Authenticator.NeedsTokenRefresh(); err != nil { + if err := c.Authenticator.RefreshToken(); err != nil { + return nil, fmt.Errorf("error refreshing token: %w", err) + } + } + + req.Header.Set("Authorization", "Bearer "+c.Authenticator.GetToken()) + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("error sending request: %w", err) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, &APIError{ + StatusCode: resp.StatusCode, + Message: string(body), + } + } + + return resp, nil +} + +type APIError struct { + StatusCode int + Message string +} + +func (e *APIError) Error() string { + return fmt.Sprintf("API error: %d - %s", e.StatusCode, e.Message) +} diff --git a/pkg/clients/cloud_client_test.go b/pkg/clients/cloud_client_test.go index 60dc6ab7..7887bd45 100644 --- a/pkg/clients/cloud_client_test.go +++ b/pkg/clients/cloud_client_test.go @@ -15,9 +15,31 @@ import ( type MockFronteggService struct { MockResponseStatus int + LastRequest *http.Request +} + +type MockAuthenticator struct { + Token string + RefreshCalled bool + NeedsRefreshCalled bool +} + +func (m *MockAuthenticator) GetToken() string { + return m.Token +} + +func (m *MockAuthenticator) RefreshToken() error { + m.RefreshCalled = true + return nil +} + +func (m *MockAuthenticator) NeedsTokenRefresh() error { + m.NeedsRefreshCalled = true + return nil } func (m *MockFronteggService) RoundTrip(req *http.Request) (*http.Response, error) { + m.LastRequest = req // Check the requested URL and return a response accordingly if strings.HasSuffix(req.URL.Path, "/api/cloud-regions") { // Mock response data @@ -62,22 +84,20 @@ func TestCloudAPIClient_ListCloudProviders(t *testing.T) { MockResponseStatus: http.StatusOK, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } - // Call the method to test providers, err := apiClient.ListCloudProviders(context.Background()) - if err != nil { - t.Fatalf("ListCloudProviders() error: %v", err) - } + require.NoError(t, err) + require.Len(t, providers, 2) - // Verify the results - wantProviderCount := 2 - if len(providers) != wantProviderCount { - t.Errorf("ListCloudProviders() got %v providers, want %v", len(providers), wantProviderCount) - } + require.True(t, mockAuthenticator.NeedsRefreshCalled) + require.Equal(t, "Bearer mock-token", mockService.LastRequest.Header.Get("Authorization")) } func TestCloudAPIClient_GetRegionDetails(t *testing.T) { @@ -85,9 +105,12 @@ func TestCloudAPIClient_GetRegionDetails(t *testing.T) { MockResponseStatus: http.StatusOK, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } provider := CloudProvider{ @@ -96,17 +119,12 @@ func TestCloudAPIClient_GetRegionDetails(t *testing.T) { Url: "http://mockendpoint.com/api/region", } - // Call the method to test region, err := apiClient.GetRegionDetails(context.Background(), provider) - if err != nil { - t.Fatalf("GetRegionDetails() error: %v", err) - } + require.NoError(t, err) + require.Equal(t, "sql.materialize.com", region.RegionInfo.SqlAddress) - // Verify the results - wantSqlAddress := "sql.materialize.com" - if region.RegionInfo.SqlAddress != wantSqlAddress { - t.Errorf("GetRegionDetails() got SqlAddress = %v, want %v", region.RegionInfo.SqlAddress, wantSqlAddress) - } + require.True(t, mockAuthenticator.NeedsRefreshCalled) + require.Equal(t, "Bearer mock-token", mockService.LastRequest.Header.Get("Authorization")) } func TestCloudAPIClient_GetHost(t *testing.T) { @@ -114,63 +132,60 @@ func TestCloudAPIClient_GetHost(t *testing.T) { MockResponseStatus: http.StatusOK, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } regionID := "aws/us-east-1" sqlAddress, err := apiClient.GetHost(context.Background(), regionID) - if err != nil { - t.Fatalf("GetHost() error: %v", err) - } + require.NoError(t, err) + require.Equal(t, "sql.materialize.com", sqlAddress) - // Verify the results - wantSqlAddress := "sql.materialize.com" - if sqlAddress != wantSqlAddress { - t.Errorf("GetHost() got SqlAddress = %v, want %v", sqlAddress, wantSqlAddress) - } + require.True(t, mockAuthenticator.NeedsRefreshCalled) } func TestCloudAPIClient_ListCloudProviders_ErrorResponse(t *testing.T) { mockService := &MockFronteggService{ - // Mock the HTTP response to return an error status code: MockResponseStatus: http.StatusInternalServerError, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } - // Call the method to test _, err := apiClient.ListCloudProviders(context.Background()) - - // Verify that an error is returned when the server responds with an error status code require.Error(t, err) } func TestCloudAPIClient_GetRegionDetails_ErrorResponse(t *testing.T) { mockService := &MockFronteggService{ - // Mock the HTTP response to return an error status code MockResponseStatus: http.StatusInternalServerError, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } + provider := CloudProvider{ ID: "aws/us-east-1", Name: "us-east-1", Url: "http://mockendpoint.com/api/region", } - // Call the method to test _, err := apiClient.GetRegionDetails(context.Background(), provider) - - // Verify that an error is returned when the server responds with an error status code require.Error(t, err) } @@ -179,9 +194,12 @@ func TestCloudAPIClient_GetHost_RegionNotFound(t *testing.T) { MockResponseStatus: http.StatusOK, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } regionID := "non-existent-region" @@ -191,32 +209,23 @@ func TestCloudAPIClient_GetHost_RegionNotFound(t *testing.T) { // Verify that an error is returned when the region is not found require.Error(t, err) require.Contains(t, err.Error(), "provider for region 'non-existent-region' not found") + + // Verify that the authentication method was called + require.True(t, mockAuthenticator.NeedsRefreshCalled) } func TestNewCloudAPIClient(t *testing.T) { - // Create a FronteggClient instance for testing - fronteggClient := &FronteggClient{} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} - // Call the NewCloudAPIClient function with a custom API endpoint customEndpoint := "http://custom-endpoint.com/api" baseEndpoint := "http://cloud.frontegg.com" - cloudAPIClient := NewCloudAPIClient(fronteggClient, customEndpoint, baseEndpoint) + cloudAPIClient := NewCloudAPIClient(mockAuthenticator, customEndpoint, baseEndpoint) - // Assert that the returned CloudAPIClient has the expected properties require.NotNil(t, cloudAPIClient) - require.Equal(t, fronteggClient, cloudAPIClient.FronteggClient) + require.Equal(t, mockAuthenticator, cloudAPIClient.Authenticator) require.NotNil(t, cloudAPIClient.HTTPClient) require.Equal(t, customEndpoint, cloudAPIClient.Endpoint) - - // Call the NewCloudAPIClient function with a different custom API endpoint - anotherCustomEndpoint := "http://another-custom-endpoint.com/api" - cloudAPIClient = NewCloudAPIClient(fronteggClient, anotherCustomEndpoint, baseEndpoint) - - // Assert that the returned CloudAPIClient has the updated custom endpoint - require.NotNil(t, cloudAPIClient) - require.Equal(t, fronteggClient, cloudAPIClient.FronteggClient) - require.NotNil(t, cloudAPIClient.HTTPClient) - require.Equal(t, anotherCustomEndpoint, cloudAPIClient.Endpoint) + require.Equal(t, baseEndpoint, cloudAPIClient.BaseEndpoint) } func TestCloudAPIClient_EnableRegion_Success(t *testing.T) { @@ -224,9 +233,12 @@ func TestCloudAPIClient_EnableRegion_Success(t *testing.T) { MockResponseStatus: http.StatusOK, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } provider := CloudProvider{ @@ -242,6 +254,9 @@ func TestCloudAPIClient_EnableRegion_Success(t *testing.T) { require.Equal(t, "http.materialize.com", region.RegionInfo.HttpAddress) require.True(t, region.RegionInfo.Resolvable) require.Equal(t, "2021-01-01T00:00:00Z", region.RegionInfo.EnabledAt) + + require.True(t, mockAuthenticator.NeedsRefreshCalled) + require.Equal(t, "Bearer mock-token", mockService.LastRequest.Header.Get("Authorization")) } func TestCloudAPIClient_EnableRegion_Error(t *testing.T) { @@ -249,9 +264,12 @@ func TestCloudAPIClient_EnableRegion_Error(t *testing.T) { MockResponseStatus: http.StatusInternalServerError, } mockClient := &http.Client{Transport: mockService} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + apiClient := &CloudAPIClient{ - FronteggClient: &FronteggClient{HTTPClient: mockClient}, - Endpoint: "http://mockendpoint.com", + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", } provider := CloudProvider{ @@ -260,7 +278,6 @@ func TestCloudAPIClient_EnableRegion_Error(t *testing.T) { Url: "http://mockendpoint.com/api/region", } - // Simulate an error response for EnableRegion _, err := apiClient.EnableRegion(context.Background(), provider) require.Error(t, err) require.Contains(t, err.Error(), "cloud API returned non-200/201 status code:") diff --git a/pkg/datasources/datasource_region_test.go b/pkg/datasources/datasource_region_test.go index a76dc9f3..64d1f0ef 100644 --- a/pkg/datasources/datasource_region_test.go +++ b/pkg/datasources/datasource_region_test.go @@ -13,6 +13,26 @@ import ( "github.com/stretchr/testify/require" ) +type MockAuthenticator struct { + Token string + RefreshCalled bool + NeedsRefreshCalled bool +} + +func (m *MockAuthenticator) GetToken() string { + return m.Token +} + +func (m *MockAuthenticator) RefreshToken() error { + m.RefreshCalled = true + return nil +} + +func (m *MockAuthenticator) NeedsTokenRefresh() error { + m.NeedsRefreshCalled = true + return nil +} + func TestRegionRead(t *testing.T) { r := require.New(t) @@ -23,6 +43,8 @@ func TestRegionRead(t *testing.T) { Transport: &testhelpers.MockCloudService{}, } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + fronteggClient := &clients.FronteggClient{ Endpoint: serverURL, HTTPClient: mockClient, @@ -30,8 +52,8 @@ func TestRegionRead(t *testing.T) { } // Create a mock cloud client mockCloudClient := &clients.CloudAPIClient{ - FronteggClient: fronteggClient, - Endpoint: serverURL, + Authenticator: mockAuthenticator, + Endpoint: serverURL, } // Create a provider meta with the mock cloud client diff --git a/pkg/resources/resource_region_test.go b/pkg/resources/resource_region_test.go index 41121660..1fb93539 100644 --- a/pkg/resources/resource_region_test.go +++ b/pkg/resources/resource_region_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "testing" - "time" "github.com/MaterializeInc/terraform-provider-materialize/pkg/clients" "github.com/MaterializeInc/terraform-provider-materialize/pkg/testhelpers" @@ -13,34 +12,47 @@ import ( "github.com/stretchr/testify/require" ) +type MockAuthenticator struct { + Token string + RefreshCalled bool + NeedsRefreshCalled bool +} + +func (m *MockAuthenticator) GetToken() string { + return m.Token +} + +func (m *MockAuthenticator) RefreshToken() error { + m.RefreshCalled = true + return nil +} + +func (m *MockAuthenticator) NeedsTokenRefresh() error { + m.NeedsRefreshCalled = true + return nil +} + func TestResourceCloudRegionCreate(t *testing.T) { r := require.New(t) - // Set up the mock cloud server testhelpers.WithMockCloudServer(t, func(serverURL string) { - // Create an http.Client that uses the mock transport mockClient := &http.Client{ Transport: &testhelpers.MockCloudService{}, } - fronteggClient := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: mockClient, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } - // Create a mock cloud client + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + mockCloudClient := &clients.CloudAPIClient{ - FronteggClient: fronteggClient, - Endpoint: serverURL, + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: serverURL, } - // Create a provider meta with the mock cloud client providerMeta := &utils.ProviderMeta{ - CloudAPI: mockCloudClient, - Frontegg: fronteggClient, + CloudAPI: mockCloudClient, + Authenticator: mockAuthenticator, } - // Create a test resource data with the Region schema d := schema.TestResourceDataRaw(t, regionSchema, map[string]interface{}{"region_id": "aws/us-east-1"}) diags := resourceCloudRegionCreate(context.Background(), d, providerMeta) @@ -56,6 +68,7 @@ func TestResourceCloudRegionCreate(t *testing.T) { r.Equal("sql.materialize.com", d.Get("sql_address")) r.True(d.Get("resolvable").(bool)) r.True(d.Get("region_state").(bool)) + r.True(mockAuthenticator.NeedsRefreshCalled) }) } @@ -67,20 +80,17 @@ func TestResourceCloudRegionRead(t *testing.T) { Transport: &testhelpers.MockCloudService{}, } - fronteggClient := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: mockClient, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} mockCloudClient := &clients.CloudAPIClient{ - FronteggClient: fronteggClient, - Endpoint: serverURL, + HTTPClient: mockClient, + Authenticator: mockAuthenticator, + Endpoint: serverURL, } providerMeta := &utils.ProviderMeta{ - CloudAPI: mockCloudClient, - Frontegg: fronteggClient, + CloudAPI: mockCloudClient, + Authenticator: mockAuthenticator, } d := schema.TestResourceDataRaw(t, regionSchema, map[string]interface{}{"region_id": "aws/us-east-1"}) @@ -99,6 +109,7 @@ func TestResourceCloudRegionRead(t *testing.T) { r.Equal("sql.materialize.com", d.Get("sql_address")) r.True(d.Get("resolvable").(bool)) r.True(d.Get("region_state").(bool)) + r.True(mockAuthenticator.NeedsRefreshCalled) }) } @@ -106,7 +117,16 @@ func TestResourceCloudRegionDelete(t *testing.T) { r := require.New(t) ctx := context.Background() - providerMeta := &utils.ProviderMeta{} + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} + mockCloudClient := &clients.CloudAPIClient{ + HTTPClient: &http.Client{}, + Authenticator: mockAuthenticator, + Endpoint: "http://mockendpoint.com", + } + providerMeta := &utils.ProviderMeta{ + CloudAPI: mockCloudClient, + Authenticator: mockAuthenticator, + } d := schema.TestResourceDataRaw(t, regionSchema, map[string]interface{}{"region_id": "aws/us-east-1"}) d.SetId("aws/us-east-1") diff --git a/pkg/resources/resource_sso_config_test.go b/pkg/resources/resource_sso_config_test.go index 155ea6e2..ea3307f7 100644 --- a/pkg/resources/resource_sso_config_test.go +++ b/pkg/resources/resource_sso_config_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "testing" - "time" "github.com/MaterializeInc/terraform-provider-materialize/pkg/clients" "github.com/MaterializeInc/terraform-provider-materialize/pkg/testhelpers" @@ -29,20 +28,17 @@ func TestSSOConfigResourceCreate(t *testing.T) { r.NotNil(d) testhelpers.WithMockFronteggServer(t, func(serverURL string) { - client := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: &http.Client{}, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} mockCloudClient := &clients.CloudAPIClient{ - FronteggClient: client, - Endpoint: serverURL, + HTTPClient: &http.Client{}, + Authenticator: mockAuthenticator, + Endpoint: serverURL, } providerMeta := &utils.ProviderMeta{ - Frontegg: client, - CloudAPI: mockCloudClient, + Authenticator: mockAuthenticator, + CloudAPI: mockCloudClient, } if err := ssoConfigCreate(context.TODO(), d, providerMeta); err != nil { @@ -62,14 +58,10 @@ func TestSSOConfigResourceRead(t *testing.T) { r := require.New(t) testhelpers.WithMockFronteggServer(t, func(serverURL string) { - client := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: &http.Client{}, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} providerMeta := &utils.ProviderMeta{ - Frontegg: client, + Authenticator: mockAuthenticator, } d := schema.TestResourceDataRaw(t, SSOConfigSchema, nil) @@ -84,27 +76,23 @@ func TestSSOConfigResourceRead(t *testing.T) { r.Equal("https://sso.example.com", d.Get("sso_endpoint")) r.Equal("mock-public-certificate\n", d.Get("public_certificate")) }) - } func TestSSOConfigResourceUpdate(t *testing.T) { r := require.New(t) testhelpers.WithMockFronteggServer(t, func(serverURL string) { - client := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: &http.Client{}, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} mockCloudClient := &clients.CloudAPIClient{ - FronteggClient: client, - Endpoint: serverURL, + HTTPClient: &http.Client{}, + Authenticator: mockAuthenticator, + Endpoint: serverURL, } providerMeta := &utils.ProviderMeta{ - Frontegg: client, - CloudAPI: mockCloudClient, + Authenticator: mockAuthenticator, + CloudAPI: mockCloudClient, } d := schema.TestResourceDataRaw(t, SSOConfigSchema, nil) @@ -121,14 +109,10 @@ func TestSSOConfigResourceDelete(t *testing.T) { r := require.New(t) testhelpers.WithMockFronteggServer(t, func(serverURL string) { - client := &clients.FronteggClient{ - Endpoint: serverURL, - HTTPClient: &http.Client{}, - TokenExpiry: time.Date(9999, 1, 1, 0, 0, 0, 0, time.UTC), - } + mockAuthenticator := &MockAuthenticator{Token: "mock-token"} providerMeta := &utils.ProviderMeta{ - Frontegg: client, + Authenticator: mockAuthenticator, } d := schema.TestResourceDataRaw(t, SSOConfigSchema, nil) diff --git a/pkg/utils/provider_meta.go b/pkg/utils/provider_meta.go index 682b38e7..4254e68c 100644 --- a/pkg/utils/provider_meta.go +++ b/pkg/utils/provider_meta.go @@ -22,6 +22,10 @@ type ProviderMeta struct { // which may involve authentication, token management, etc. Frontegg *clients.FronteggClient + // Authenticator is the interface used to manage authentication and token + // management for the Frontegg API. + Authenticator clients.Authenticator + // CloudAPI is the client used for interactions with the cloud API CloudAPI *clients.CloudAPIClient