From 78e89fe53278a11c3e4f86ac7b921313bf54dfb4 Mon Sep 17 00:00:00 2001 From: Hayden Spitzley Date: Thu, 4 Apr 2024 21:56:40 -0600 Subject: [PATCH] refactor and fix tests --- oidc_cli/oidc_impl/cache/cache.go | 54 +++++++++++++++++--------- oidc_cli/oidc_impl/cache/cache_test.go | 14 +++++-- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/oidc_cli/oidc_impl/cache/cache.go b/oidc_cli/oidc_impl/cache/cache.go index 468dc28c..b973fe1d 100644 --- a/oidc_cli/oidc_impl/cache/cache.go +++ b/oidc_cli/oidc_impl/cache/cache.go @@ -65,7 +65,6 @@ func (c *Cache) refresh(ctx context.Context) (*client.Token, error) { if err != nil { return nil, err } - // if we have a valid token, use it if cachedToken.IsFresh() { return cachedToken, nil @@ -89,17 +88,13 @@ func (c *Cache) refresh(ctx context.Context) (*client.Token, error) { return nil, errors.Wrap(err, "unable to marshall token") } - // compress and save token to storage - var buf bytes.Buffer - gz := gzip.NewWriter(&buf) - if _, err := gz.Write([]byte(strToken)); err != nil { - return nil, fmt.Errorf("failed to write to gzip: %w", err) - } - if err := gz.Close(); err != nil { - return nil, fmt.Errorf("failed to close gzip: %w", err) + // gzip encode and save token to storage + compressedToken, err := compressToken(strToken) + if err != nil { + return nil, errors.Wrap(err, "unable to compress token") } - err = c.storage.Set(ctx, buf.String()) + err = c.storage.Set(ctx, compressedToken) if err != nil { return nil, errors.Wrap(err, "Unable to cache the strToken") } @@ -114,22 +109,15 @@ func (c *Cache) readFromStorage(ctx context.Context) (*client.Token, error) { if err != nil { return nil, err } - if cached == nil { return nil, nil } // decode gzip data - reader := bytes.NewReader([]byte(*cached)) - gzreader, err := gzip.NewReader(reader) - if err != nil { - return nil, fmt.Errorf("failed to create gzip reader: %w", err) - } - decompressed, err := io.ReadAll(gzreader) + decompressedStr, err := decompressToken(*cached) if err != nil { - return nil, fmt.Errorf("failed to read gzip data: %w", err) + return nil, fmt.Errorf("failed to decompress token: %w", err) } - decompressedStr := string(decompressed) cachedToken, err := client.TokenFromString(&decompressedStr) if err != nil { @@ -141,3 +129,31 @@ func (c *Cache) readFromStorage(ctx context.Context) (*client.Token, error) { } return cachedToken, nil } + +func compressToken(token string) (string, error) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + if _, err := gz.Write([]byte(token)); err != nil { + return "", fmt.Errorf("failed to write to gzip: %w", err) + } + if err := gz.Close(); err != nil { + return "", fmt.Errorf("failed to close gzip: %w", err) + } + return buf.String(), nil +} + +func decompressToken(token string) (string, error) { + reader := bytes.NewReader([]byte(token)) + gzreader, err := gzip.NewReader(reader) + if err != nil { + return "", fmt.Errorf("failed to create gzip reader: %w", err) + } + decompressed, err := io.ReadAll(gzreader) + if err != nil { + return "", fmt.Errorf("failed to read gzip data: %w", err) + } + if err := gzreader.Close(); err != nil { + return "", fmt.Errorf("failed to close gzip: %w", err) + } + return string(decompressed), nil +} diff --git a/oidc_cli/oidc_impl/cache/cache_test.go b/oidc_cli/oidc_impl/cache/cache_test.go index 2a9f76db..5016a500 100644 --- a/oidc_cli/oidc_impl/cache/cache_test.go +++ b/oidc_cli/oidc_impl/cache/cache_test.go @@ -57,7 +57,9 @@ func TestCorruptedCache(t *testing.T) { r := require.New(t) s := genStorage() ctx := context.Background() - err := s.Set(ctx, "garbage token") + compressed, err := compressToken("garbage token") + r.NoError(err) + err = s.Set(ctx, compressed) r.NoError(err) u := uuid.New() @@ -83,7 +85,10 @@ func TestCorruptedCache(t *testing.T) { r.NoError(err) r.NotNil(cachedToken) - tok, err := client.TokenFromString(cachedToken) + decompressedToken, err := decompressToken(*cachedToken) + r.NoError(err) + + tok, err := client.TokenFromString(&decompressedToken) r.NoError(err) r.NotNil(t) @@ -110,7 +115,10 @@ func TestCachedToken(t *testing.T) { marshalled, err := freshToken.Marshal() r.NoError(err) - err = s.Set(ctx, marshalled) + compressed, err := compressToken(marshalled) + r.NoError(err) + + err = s.Set(ctx, compressed) r.NoError(err) refresh := func(ctx context.Context, c *client.Token) (*client.Token, error) {