diff --git a/.hlint.yaml b/.hlint.yaml index 5ace677d51..6b7cd8c251 100644 --- a/.hlint.yaml +++ b/.hlint.yaml @@ -41,6 +41,22 @@ - suggest: {lhs: "(Data.Set.size x) == 0" , rhs: "Data.Set.null x"} - suggest: {lhs: "(Data.Set.size x) /= 0" , rhs: "not $ Data.Set.null x"} +- group: + name: stm + enabled: true + rules: + - hint: {lhs: Control.Concurrent.STM.takeTMVar, rhs: Control.Concurrent.STM.tryTakeTMVar, note: "Blocks if the TMVar is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.putTMVar, rhs: Control.Concurrent.STM.writeTMVar, note: "Blocks if the TMVar is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.readTMVar, rhs: Control.Concurrent.STM.tryReadTMVar, note: "Blocks if the TMVar is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TQueue.readTQueue, rhs: Control.Concurrent.STM.TQueue.tryReadTQueue, note: "Retries and blocks if the TBQueue is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TQueue.peekTQueue, rhs: Control.Concurrent.STM.TQueue.tryPeekTQueue, note: "Retries and blocks if the TBQueue is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBQueue.readTBQueue, rhs: Control.Concurrent.STM.TBQueue.tryReadTBQueue, note: "Retries and blocks if the TBQueue is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBQueue.peekTBQueue, rhs: Control.Concurrent.STM.TBQueue.tryPeekTBQueue, note: "Retries and blocks if the TBQueue is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBQueue.unGetTBQueue, rhs: Control.Concurrent.STM.TBQueue.unGetTBQueue, note: "Retries and blocks if the TBQueue is full. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBMQueue.readTBMQueue, rhs: Control.Concurrent.STM.TBMQueue.tryReadTBMQueue, note: "Retries and blocks if the TBMQueue is empty and open. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBMQueue.peekTBMQueue, rhs: Control.Concurrent.STM.TBMQueue.tryPeekTBMQueue, note: "Retries and blocks if the TBMQueue is empty. Be sure this is correct behavior."} + - hint: {lhs: Control.Concurrent.STM.TBMQueue.writeTBMQueue, rhs: Control.Concurrent.STM.TBMQueue.tryWriteTBMQueue, note: "Retries and blocks if the TBMQueue is full. Be sure this is correct behavior."} + # Forbidden items, only allowed in compile-time code, or test code (however, it should be avoided in tests as much as possible). - functions: - {name: error, within: [Data.String.Conversion, Control.Effect.Replay]} diff --git a/Changelog.md b/Changelog.md index 77343fa42a..9a72fa49e6 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,9 +1,14 @@ # FOSSA CLI Changelog +## 3.8.37 + +- Container Scans: Bugfix for some registry scans that fail with an STM error. ([#1370](https://github.com/fossas/fossa-cli/pull/1370)) + ## v3.8.36 - `fossa feedback`: Allow users to provide feedback on their cli experience ([#1368](https://github.com/fossas/fossa-cli/pull/1368)) - Add preflight checks to validate API key, connection to FOSSA app, and ability to write to temp directory in relevant commands + ## v3.8.35 - Running `fossa analyze --detect-vendored` no longer fails if there are no detected vendored dependencies ([#1373](https://github.com/fossas/fossa-cli/pull/1373)). diff --git a/fourmolu.yaml b/fourmolu.yaml index c576011a2a..9e0536d964 100644 --- a/fourmolu.yaml +++ b/fourmolu.yaml @@ -31,3 +31,4 @@ fixities: - infixr 3 && - infix 4 == - infixl 4 <$>, <*> +- infixr 6 <> \ No newline at end of file diff --git a/integration-test/Container/AnalysisSpec.hs b/integration-test/Container/AnalysisSpec.hs new file mode 100644 index 0000000000..99072ba1bc --- /dev/null +++ b/integration-test/Container/AnalysisSpec.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE OverloadedRecordDot #-} + +module Container.AnalysisSpec (spec) where + +import App.Fossa.Config.Common (ScanDestination (OutputStdout)) +import App.Fossa.Config.Container.Analyze (ContainerAnalyzeConfig (..)) +import App.Fossa.Config.Container.Common (ImageText (ImageText)) +import App.Fossa.Container.AnalyzeNative (analyzeExperimental) +import App.Types (OverrideProject (OverrideProject)) +import Container.FixtureUtils (runContainerEffs) +import Container.Types ( + ContainerScan (imageData, imageTag), + ContainerScanImage (imageLayers, imageOs, imageOsRelease), + ) +import Data.Flag (toFlag') +import Diag.Result (Result (..)) +import Effect.Logger (Severity (SevInfo)) +import Test.Hspec (Spec, aroundAll, describe, it, shouldBe, shouldSatisfy) + +spec :: Spec +spec = describe "Container Scanning" registrySourceAnalysis + +registrySourceCfg :: ContainerAnalyzeConfig +registrySourceCfg = + ContainerAnalyzeConfig + { scanDestination = OutputStdout + , revisionOverride = OverrideProject Nothing Nothing Nothing + , imageLocator = ImageText "public.ecr.aws/docker/library/alpine:3.19.1" + , jsonOutput = toFlag' False + , usesExperimentalScanner = True + , dockerHost = "" + , arch = "amd64" + , severity = SevInfo + , onlySystemDeps = False + , filterSet = mempty + } + +runAnalyze :: ContainerAnalyzeConfig -> (ContainerScan -> IO ()) -> IO () +runAnalyze analyzeCfg action = do + res <- runContainerEffs (analyzeExperimental analyzeCfg) + case res of + Failure _ errGroup -> fail . show $ errGroup + Success _ a -> action a + +registrySourceAnalysis :: Spec +registrySourceAnalysis = do + aroundAll (runAnalyze registrySourceCfg) $ do + describe "Container analysis from registry source" $ do + it "Has the correct OS" $ + \res -> res.imageData.imageOs `shouldBe` "alpine" + it "Has the correct OS release version" $ + \res -> res.imageData.imageOsRelease `shouldBe` "3.19.1" + it "Has the expected image tag" $ + \res -> res.imageTag `shouldBe` "public.ecr.aws/docker/library/alpine" + it "Has at least one layer" $ + \res -> res.imageData.imageLayers `shouldSatisfy` (not . null) diff --git a/integration-test/Container/FixtureUtils.hs b/integration-test/Container/FixtureUtils.hs new file mode 100644 index 0000000000..f228b65a39 --- /dev/null +++ b/integration-test/Container/FixtureUtils.hs @@ -0,0 +1,32 @@ +module Container.FixtureUtils ( + ContainerAnalysisC, + runContainerEffs, +) where + +import Control.Carrier.Diagnostics (DiagnosticsC, runDiagnostics) +import Control.Carrier.Stack (StackC, runStack) +import Control.Carrier.Telemetry (IgnoreTelemetryC, withoutTelemetry) +import Data.Function ((&)) +import Diag.Result (Result) +import Effect.Exec (ExecIOC, runExecIO) +import Effect.Logger (LoggerC, Severity (SevWarn), withDefaultLogger) +import Effect.ReadFS (ReadFSIOC, runReadFSIO) +import Type.Operator (type ($)) + +type ContainerAnalysisC m = + ExecIOC + $ ReadFSIOC + $ LoggerC + $ DiagnosticsC + $ StackC + $ IgnoreTelemetryC m + +runContainerEffs :: ContainerAnalysisC IO a -> IO (Result a) +runContainerEffs f = + f + & runExecIO + & runReadFSIO + & withDefaultLogger SevWarn + & runDiagnostics + & runStack + & withoutTelemetry diff --git a/spectrometer.cabal b/spectrometer.cabal index 2a02463577..0cff2b8f03 100644 --- a/spectrometer.cabal +++ b/spectrometer.cabal @@ -704,6 +704,8 @@ test-suite integration-tests Analysis.RustSpec Analysis.ScalaSpec Analysis.SwiftSpec + Container.AnalysisSpec + Container.FixtureUtils SpecHook build-tool-depends: hspec-discover:hspec-discover ^>=2.10.0.1 diff --git a/src/App/Fossa/Container/AnalyzeNative.hs b/src/App/Fossa/Container/AnalyzeNative.hs index 98bc3c648c..72d7e11ae6 100644 --- a/src/App/Fossa/Container/AnalyzeNative.hs +++ b/src/App/Fossa/Container/AnalyzeNative.hs @@ -80,8 +80,8 @@ analyzeExperimental :: , Has Telemetry sig m ) => ContainerAnalyzeConfig -> - m Aeson.Value -analyzeExperimental cfg = + m ContainerScan +analyzeExperimental cfg = do case Config.severity cfg of SevDebug -> do (scope, res) <- collectDebugBundle cfg $ Diag.errorBoundaryIO $ analyze cfg @@ -99,7 +99,7 @@ analyze :: , Has Debug sig m ) => ContainerAnalyzeConfig -> - m Aeson.Value + m ContainerScan analyze cfg = do _ <- case scanDestination cfg of OutputStdout -> pure () @@ -119,7 +119,7 @@ analyze cfg = do UploadScan apiOpts projectMeta -> void $ runFossaApiClient apiOpts $ uploadScan revision projectMeta (jsonOutput cfg) scannedImage - pure $ Aeson.toJSON scannedImage + pure scannedImage uploadScan :: ( Has Diagnostics sig m diff --git a/src/Control/Carrier/ContainerRegistryApi.hs b/src/Control/Carrier/ContainerRegistryApi.hs index ec90ce114e..675e4a9dd4 100644 --- a/src/Control/Carrier/ContainerRegistryApi.hs +++ b/src/Control/Carrier/ContainerRegistryApi.hs @@ -46,7 +46,6 @@ import Control.Carrier.ContainerRegistryApi.Common ( RegistryCtx (RegistryCtx), fromResponse, getContentType, - getToken, ) import Control.Carrier.Finally (runFinally) @@ -54,7 +53,7 @@ import Control.Carrier.Reader (ReaderC, ask, runReader) import Control.Carrier.Simple (SimpleC, interpret) import Control.Carrier.StickyLogger (runStickyLogger) import Control.Carrier.TaskPool (withTaskPool) -import Control.Concurrent (getNumCapabilities) +import Control.Concurrent (getNumCapabilities, myThreadId) import Control.Concurrent.STM (newEmptyTMVarIO) import Control.Effect.ContainerRegistryApi ( ContainerRegistryApiF (ExportImage, GetImageManifest), @@ -75,13 +74,16 @@ import Data.Aeson (eitherDecode, encode) import Data.ByteString (ByteString, writeFile) import Data.ByteString.Lazy qualified as ByteStringLazy import Data.Conduit.Zlib (ungzip) -import Data.Maybe (fromMaybe) +import Data.Maybe (fromMaybe, isNothing) import Data.String.Conversion ( LazyStrict (toStrict), + showText, toString, toText, ) import Data.Text (Text) +import Data.UUID qualified as UUID (toText) +import Data.UUID.V4 qualified as UUID (nextRandom) import Effect.Logger ( Logger, Pretty (pretty), @@ -251,31 +253,41 @@ exportBlob :: (RepoDigest, Bool, Text) -> m (Path Abs File) exportBlob manager imgSrc dir (digest, isGzip, targetFilename) = do - ctx <- ask - let sinkTarget :: Path Abs File - sinkTarget = dir Path (toString targetFilename) - - let imgSrc' = imgSrc{registryContainerRepositoryReference = RepoReferenceDigest digest} - - -- Prepare request with necessary authorization - req <- blobEndpoint imgSrc' - token <- getAuthToken (registryCred imgSrc) req manager Nothing =<< getToken ctx - let req' = applyAuthToken token req - - -- Download image artifact - sendIO . runResourceT $ do - response <- HTTPConduit.http req' manager - runConduit $ - HTTPConduit.responseBody response - .| (if isGzip then ungzip else idC) - .| sinkFile (toString sinkTarget) - - logInfo . pretty $ - if isGzip - then "Gzip extracted & downloaded: " <> targetFilename - else "Downloaded: " <> targetFilename - - pure sinkTarget + exportJobId <- sendIO UUID.nextRandom + threadId <- sendIO myThreadId + let exportDesc = "Export job ID: " <> UUID.toText exportJobId <> ", Export thread ID: " <> showText threadId + context exportDesc $ do + let sinkTarget :: Path Abs File + sinkTarget = dir Path (toString targetFilename) + + let imgSrc' = imgSrc{registryContainerRepositoryReference = RepoReferenceDigest digest} + + -- Prepare request with necessary authorization + req <- blobEndpoint imgSrc' + -- The current RegistryCtx is shared amongst multiple threads exporting blobs. + -- This could potentially be a problem if layers in a manifest file need different tokens to fetch. + -- I think the only way this *might* be possible is through redirects when fetching blobs. + -- I think the registry fetcher would still make progress in that case, but would just make more token reqs than necessary. + token <- getAuthToken (registryCred imgSrc) req manager Nothing =<< ask + -- This message generally means that auth is not required. + -- It may also indicate a bug in how we update/share tokens between threads. + when (isNothing token) $ logDebug "Got Nothing as a token." + let req' = applyAuthToken token req + + -- Download image artifact + sendIO . runResourceT $ do + response <- HTTPConduit.http req' manager + runConduit $ + HTTPConduit.responseBody response + .| (if isGzip then ungzip else idC) + .| sinkFile (toString sinkTarget) + + logInfo . pretty $ + if isGzip + then "Gzip extracted & downloaded: " <> targetFilename + else "Downloaded: " <> targetFilename + + pure sinkTarget -- | Identity Conduit idC :: (PrimMonad m) => ConduitT ByteString ByteString m () diff --git a/src/Control/Carrier/ContainerRegistryApi/Authorization.hs b/src/Control/Carrier/ContainerRegistryApi/Authorization.hs index 81142989cf..a414bb2e49 100644 --- a/src/Control/Carrier/ContainerRegistryApi/Authorization.hs +++ b/src/Control/Carrier/ContainerRegistryApi/Authorization.hs @@ -19,7 +19,7 @@ import Control.Carrier.ContainerRegistryApi.Common ( getToken, logHttp, originalReqUri, - updateToken, + safeReplaceToken, ) import Control.Carrier.ContainerRegistryApi.Errors ( ContainerRegistryApiErrorBody, @@ -29,15 +29,17 @@ import Control.Carrier.ContainerRegistryApi.Errors ( import Control.Effect.Diagnostics (Diagnostics, fatal, fatalText, fromMaybeText) import Control.Effect.Lift (Lift, sendIO) import Control.Effect.Reader (Reader, ask) -import Data.Aeson (FromJSON (parseJSON), decode', eitherDecode, withObject, (.:)) +import Control.Monad (unless, when) +import Data.Aeson (FromJSON (parseJSON), decode', eitherDecode, withObject, (.:), (.:?)) import Data.ByteString.Lazy qualified as ByteStringLazy import Data.Map (Map) import Data.Map qualified as Map +import Data.Maybe (isJust) import Data.String.Conversion (ConvertUtf8 (decodeUtf8), encodeUtf8, toString, toText) import Data.Text (Text, isInfixOf) import Data.Text qualified as Text import Data.Void (Void) -import Effect.Logger (Logger) +import Effect.Logger (Logger, Pretty (pretty), logDebug) import Network.HTTP.Client ( Manager, Request (host, method, shouldStripHeaderOnRedirect), @@ -79,9 +81,8 @@ mkRequest :: Request -> -- Request to make m (Response ByteStringLazy.ByteString) mkRequest manager registryCred accepts req = do - token <- getToken =<< ask - token' <- getAuthToken registryCred req manager accepts token - logHttp (applyContentType accepts $ applyAuthToken token' req) manager + token <- getAuthToken registryCred req manager accepts =<< ask + logHttp (applyContentType accepts $ applyAuthToken token req) manager applyContentType :: Maybe [Text] -> Request -> Request applyContentType c r = case c of @@ -110,7 +111,7 @@ stripAuthHeaderOnRedirect r = isAzure :: Bool isAzure = "azurecr.io" `isInfixOf` decodeUtf8 (host r) --- | Generates Auth Token For Request. +-- | Get an auth token for a given resource if necessary and update the RegistryCtx. -- -- Refer to: -- @@ -136,7 +137,6 @@ getAuthToken :: ( Has (Lift IO) sig m , Has Diagnostics sig m , Has Logger sig m - , Has (Reader RegistryCtx) sig m ) => Maybe (Text, Text) -> -- | Username and Password to user when retrieving authorization token @@ -146,11 +146,12 @@ getAuthToken :: -- | Manager to use for requests Maybe [Text] -> -- | Content-Types for Accept Header - Maybe AuthToken -> - -- | Existing Token (if any) + RegistryCtx -> + -- | The registry context to retrieve/update tokens in. m (Maybe AuthToken) -getAuthToken cred reqAttempt manager accepts token = do - let request' = applyContentType accepts (applyAuthForExistingToken $ reqAttempt{method = "HEAD"}) +getAuthToken cred reqAttempt manager accepts registryCtx = do + token <- getToken registryCtx + let request' = applyContentType accepts (applyAuthToken token $ reqAttempt{method = "HEAD"}) response <- logHttp request' manager case (decode' $ responseBody response, statusCode . responseStatus $ response) of @@ -158,7 +159,26 @@ getAuthToken cred reqAttempt manager accepts token = do -- meaning that our token is valid, or we do not require authorization token. (Nothing, 200) -> pure token (_, 401) -> do - case parse parseAuthChallenge "" <$> getHeaderValue hWWWAuthenticate (responseHeaders response) of + didReplace <- safeReplaceToken registryCtx (respondToChallenge response) + unless didReplace $ logDebug "Token is already being updated. Waiting..." + getToken registryCtx + + -- - + -- Other Errors + -- - + (Just (apiErrors :: ContainerRegistryApiErrorBody), _) -> fatal (originalReqUri response, apiErrors) + (Nothing, _) -> fatal $ UnknownApiError (originalReqUri response) $ responseStatus response + where + respondToChallenge :: + ( Has (Lift IO) sig m + , Has Logger sig m + , Has Diagnostics sig m + ) => + Response a -> + m AuthToken + respondToChallenge response = do + let rawChallenge = getHeaderValue hWWWAuthenticate (responseHeaders response) + case parse parseAuthChallenge "" <$> rawChallenge of -- - -- Did not receive valid auth challenge -- - @@ -176,19 +196,8 @@ getAuthToken cred reqAttempt manager accepts token = do -- registry context. -- - Just (Right authChallenge) -> do - token' <- getTokenFromAuthChallenge cred authChallenge manager - ctx <- ask - updateToken ctx token' - pure (Just token') - - -- - - -- Other Errors - -- - - (Just (apiErrors :: ContainerRegistryApiErrorBody), _) -> fatal (originalReqUri response, apiErrors) - (Nothing, _) -> fatal $ UnknownApiError (originalReqUri response) $ responseStatus response - where - applyAuthForExistingToken :: Request -> Request - applyAuthForExistingToken = applyAuthToken token + logDebug $ "Got auth challenge: " <> (pretty . show $ authChallenge) + getTokenFromAuthChallenge cred authChallenge manager -- | Retrieves Token from Authorization Server. -- @@ -224,7 +233,12 @@ getTokenFromAuthChallenge cred (BearerAuthChallenge (RegistryBearerChallenge url response <- fromResponse =<< logHttp req' manager case eitherDecode $ responseBody response of Left err -> fatal . FailedToParseAuthChallenge $ toText err - Right tokenResponse -> pure $ BearerAuthToken $ unToken tokenResponse + Right tokenResponse -> do + let expiry = expiresIn tokenResponse + when (isJust expiry) $ + logDebug $ + "Token expires in: " <> pretty expiry + pure $ BearerAuthToken $ unToken tokenResponse where -- \| Authorization Server Endpoint. authTokenEndpoint :: Has (Lift IO) sig m => m (Request) @@ -246,12 +260,18 @@ data RegistryBearerChallenge = RegistryBearerChallenge } deriving (Show, Eq, Ord) -newtype AuthChallengeResponse = AuthChallengeResponse {unToken :: Text} deriving (Eq, Show, Ord) +data AuthChallengeResponse = AuthChallengeResponse + { unToken :: Text + , expiresIn :: Maybe Int + } + deriving (Eq, Show, Ord) + instance FromJSON AuthChallengeResponse where parseJSON = withObject "AuthChallengeResponse" $ \o -> ( AuthChallengeResponse <$> o .: "token" <|> AuthChallengeResponse <$> o .: "access_token" ) + <*> o .:? "expires_in" -- | Parses Authorization Header. -- diff --git a/src/Control/Carrier/ContainerRegistryApi/Common.hs b/src/Control/Carrier/ContainerRegistryApi/Common.hs index 26abd53cbb..657ab9d16a 100644 --- a/src/Control/Carrier/ContainerRegistryApi/Common.hs +++ b/src/Control/Carrier/ContainerRegistryApi/Common.hs @@ -9,6 +9,7 @@ module Control.Carrier.ContainerRegistryApi.Common ( AuthToken (..), getToken, updateToken, + safeReplaceToken, ) where import Control.Algebra (Has) @@ -16,11 +17,13 @@ import Control.Carrier.ContainerRegistryApi.Errors ( ContainerRegistryApiErrorBody, UnknownApiError (UnknownApiError), ) -import Control.Concurrent.STM (STM, TMVar, atomically, putTMVar, tryReadTMVar) +import Control.Concurrent.STM (STM, TMVar, atomically, retry, tryReadTMVar, tryTakeTMVar, writeTMVar) import Control.Effect.Diagnostics (Diagnostics, fatal) +import Control.Effect.Exception (onException) import Control.Effect.Lift (Lift, sendIO) import Data.Aeson (decode') import Data.ByteString.Lazy qualified as ByteStringLazy +import Data.Functor (void) import Data.List (find) import Data.String.Conversion (ConvertUtf8 (encodeUtf8), decodeUtf8) import Data.Text (Text) @@ -42,7 +45,7 @@ logHttp :: (Has (Lift IO) sig m, Has Logger sig m) => Request -> Manager -> m (R logHttp req manager = do logDebug summarizeRequest resp <- sendIO $ httpLbs req manager - logDebug . summarizeResponse $ resp + logDebug $ summarizeResponse resp pure resp where summarizeRequest :: Doc AnsiStyle @@ -96,18 +99,58 @@ data AuthToken | BasicAuthToken Text Text deriving (Show, Eq, Ord) +data RegistryHandle + = -- | Another thread is working on fetching a new token. + Updating + | -- | This Token is ready to be used. + Ready AuthToken + -- | Wrapper for context - e.g. access token etc. newtype RegistryCtx = RegistryCtx - { registryAccessToken :: TMVar AuthToken - } + {registryAccessToken :: TMVar RegistryHandle} -- | Gets access token from registry context. +-- If there isn't one, returns 'Nothing'. +-- If the token is in the process of being updated, then wait for it. getToken :: (Has (Lift IO) sig m) => RegistryCtx -> m (Maybe AuthToken) -getToken = sendSTM . tryReadTMVar . registryAccessToken +getToken ctx = sendSTM $ do + m <- tryReadTMVar . registryAccessToken $ ctx + case m of + -- If something else is replacing the token, wait until a change and try again. + Just Updating -> retry + Nothing -> pure Nothing + Just (Ready tok) -> pure . Just $ tok -- | Updates access token from registry context. updateToken :: Has (Lift IO) sig m => RegistryCtx -> AuthToken -> m () -updateToken token newVal = sendSTM $ putTMVar (registryAccessToken token) newVal +updateToken token newVal = do + sendSTM $ writeTMVar (registryAccessToken token) (Ready newVal) + +-- | Try to replace the token in ctx after retrieving it using the given action. +-- If another thread is trying to replace the token, then do nothing. +-- Returns True when successfully written, or False otherwise. +safeReplaceToken :: Has (Lift IO) sig m => RegistryCtx -> m AuthToken -> m Bool +safeReplaceToken ctx getNewToken = do + let tokVar = registryAccessToken ctx + + (shouldUpdate, exceptionCleanupAction) <- sendSTM $ do + m <- tryReadTMVar (registryAccessToken ctx) + + case m of + Just Updating -> pure (False, pure ()) + -- Existing or new token, should replace. + t -> do + writeTMVar tokVar Updating -- Signal to other threads there's an update in progress + pure (True, maybe (void $ tryTakeTMVar tokVar) (writeTMVar tokVar) t) + + if shouldUpdate + then do + -- If there's some new exception, clean up by putting the old token back. + -- This gives other threads the opportunity to try to fetch a new token and exit gracefully. + newToken <- getNewToken `onException` (sendSTM exceptionCleanupAction) + updateToken ctx newToken + pure True + else pure False sendSTM :: Has (Lift IO) sig m => STM a -> m a sendSTM = sendIO . atomically diff --git a/src/Control/Carrier/Threaded.hs b/src/Control/Carrier/Threaded.hs index d00b6c78c5..31a9564f9e 100644 --- a/src/Control/Carrier/Threaded.hs +++ b/src/Control/Carrier/Threaded.hs @@ -1,3 +1,7 @@ +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Use writeTMVar" #-} +{-# HLINT ignore "Use tryReadTMVar" #-} module Control.Carrier.Threaded ( fork, kill, diff --git a/src/Data/Flag.hs b/src/Data/Flag.hs index 657a0883a2..fab617f008 100644 --- a/src/Data/Flag.hs +++ b/src/Data/Flag.hs @@ -4,6 +4,7 @@ module Data.Flag ( fromFlag, toFlag, flagOpt, + toFlag', ) where import Data.Aeson (ToJSON (toEncoding), defaultOptions, genericToEncoding) @@ -23,6 +24,10 @@ fromFlag _ = getFlag toFlag :: a -> Bool -> Flag a toFlag _ = Flag +-- | 'toFlag', but use type inference to fill in Flag's type variable. +toFlag' :: forall a. Bool -> Flag a +toFlag' = Flag @a + -- | optparse-applicative helper flagOpt :: a -> Mod FlagFields Bool -> Parser (Flag a) flagOpt a fields = toFlag a <$> switch fields