diff --git a/protocol/app/upgrades.go b/protocol/app/upgrades.go index df4bf3eb38..3103802ac6 100644 --- a/protocol/app/upgrades.go +++ b/protocol/app/upgrades.go @@ -30,6 +30,7 @@ func (app *App) setupUpgradeHandlers() { v7_0_0.CreateUpgradeHandler( app.ModuleManager, app.configurator, + app.AccountKeeper, app.PricesKeeper, app.VaultKeeper, ), diff --git a/protocol/app/upgrades/v7.0.0/upgrade.go b/protocol/app/upgrades/v7.0.0/upgrade.go index d5a0b8e970..afc87cfada 100644 --- a/protocol/app/upgrades/v7.0.0/upgrade.go +++ b/protocol/app/upgrades/v7.0.0/upgrade.go @@ -8,6 +8,8 @@ import ( upgradetypes "cosmossdk.io/x/upgrade/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/module" + authkeeper "github.com/cosmos/cosmos-sdk/x/auth/keeper" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" "github.com/dydxprotocol/v4-chain/protocol/lib" "github.com/dydxprotocol/v4-chain/protocol/lib/slinky" pricestypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types" @@ -20,6 +22,75 @@ const ( QUOTE_QUANTUMS_PER_MEGAVAULT_SHARE = 1_000_000 ) +var ( + ModuleAccsToInitialize = []string{ + vaulttypes.MegavaultAccountName, + } +) + +// This module account initialization logic is copied from v3.0.0 upgrade handler. +func initializeModuleAccs(ctx sdk.Context, ak authkeeper.AccountKeeper) { + for _, modAccName := range ModuleAccsToInitialize { + // Get module account and relevant permissions from the accountKeeper. + addr, perms := ak.GetModuleAddressAndPermissions(modAccName) + if addr == nil { + panic(fmt.Sprintf( + "Did not find %v in `ak.GetModuleAddressAndPermissions`. This is not expected. Skipping.", + modAccName, + )) + } + + // Try to get the account in state. + acc := ak.GetAccount(ctx, addr) + if acc != nil { + // Account has been initialized. + macc, isModuleAccount := acc.(sdk.ModuleAccountI) + if isModuleAccount { + // Module account was correctly initialized. Skipping + ctx.Logger().Info(fmt.Sprintf( + "module account %+v was correctly initialized. No-op", + macc, + )) + continue + } + // Module account has been initialized as a BaseAccount. Change to module account. + // Note: We need to get the base account to retrieve its account number, and convert it + // in place into a module account. + baseAccount, ok := acc.(*authtypes.BaseAccount) + if !ok { + panic(fmt.Sprintf( + "cannot cast %v into a BaseAccount, acc = %+v", + modAccName, + acc, + )) + } + newModuleAccount := authtypes.NewModuleAccount( + baseAccount, + modAccName, + perms..., + ) + ak.SetModuleAccount(ctx, newModuleAccount) + ctx.Logger().Info(fmt.Sprintf( + "Successfully converted %v to module account in state: %+v", + modAccName, + newModuleAccount, + )) + continue + } + + // Account has not been initialized at all. Initialize it as module. + // Implementation taken from + // https://github.com/dydxprotocol/cosmos-sdk/blob/bdf96fdd/x/auth/keeper/keeper.go#L213 + newModuleAccount := authtypes.NewEmptyModuleAccount(modAccName, perms...) + maccI := (ak.NewAccount(ctx, newModuleAccount)).(sdk.ModuleAccountI) // this set the account number + ak.SetModuleAccount(ctx, maccI) + ctx.Logger().Info(fmt.Sprintf( + "Successfully initialized module account in state: %+v", + newModuleAccount, + )) + } +} + func initCurrencyPairIDCache(ctx sdk.Context, k pricestypes.PricesKeeper) { marketParams := k.GetAllMarketParams(ctx) for _, mp := range marketParams { @@ -126,6 +197,7 @@ func migrateVaultSharesToMegavaultShares(ctx sdk.Context, k vaultkeeper.Keeper) func CreateUpgradeHandler( mm *module.Manager, configurator module.Configurator, + accountKeeper authkeeper.AccountKeeper, pricesKeeper pricestypes.PricesKeeper, vaultKeeper vaultkeeper.Keeper, ) upgradetypes.UpgradeHandler { @@ -133,6 +205,9 @@ func CreateUpgradeHandler( sdkCtx := lib.UnwrapSDKContext(ctx, "app/upgrades") sdkCtx.Logger().Info(fmt.Sprintf("Running %s Upgrade...", UpgradeName)) + // Initialize module accounts. + initializeModuleAccs(sdkCtx, accountKeeper) + // Initialize the currency pair ID cache for all existing market params. initCurrencyPairIDCache(sdkCtx, pricesKeeper) diff --git a/protocol/app/upgrades/v7.0.0/upgrade_container_test.go b/protocol/app/upgrades/v7.0.0/upgrade_container_test.go index ccc0c79883..dc8ef454ff 100644 --- a/protocol/app/upgrades/v7.0.0/upgrade_container_test.go +++ b/protocol/app/upgrades/v7.0.0/upgrade_container_test.go @@ -8,6 +8,7 @@ import ( "github.com/cosmos/gogoproto/proto" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" v_7_0_0 "github.com/dydxprotocol/v4-chain/protocol/app/upgrades/v7.0.0" "github.com/dydxprotocol/v4-chain/protocol/dtypes" "github.com/dydxprotocol/v4-chain/protocol/testing/containertest" @@ -48,9 +49,12 @@ func preUpgradeChecks(node *containertest.Node, t *testing.T) { } func postUpgradeChecks(node *containertest.Node, t *testing.T) { - // Add test for your upgrade handler logic below + // Check that vault quoting params are successfully migrated to vault params. postUpgradeVaultParamsCheck(node, t) + // Check that vault shares are successfully migrated to megavault shares. postUpgradeMegavaultSharesCheck(node, t) + // Check that megavault module account is successfully initialized. + postUpgradeMegavaultModuleAccCheck(node, t) // Check that the affiliates module has been initialized with the default tiers. postUpgradeAffiliatesModuleTiersCheck(node, t) @@ -182,3 +186,20 @@ func postUpgradeAffiliatesModuleTiersCheck(node *containertest.Node, t *testing. require.NoError(t, err) require.Equal(t, affiliatestypes.DefaultAffiliateTiers, affiliateTiersResp.Tiers) } + +func postUpgradeMegavaultModuleAccCheck(node *containertest.Node, t *testing.T) { + resp, err := containertest.Query( + node, + authtypes.NewQueryClient, + authtypes.QueryClient.ModuleAccountByName, + &authtypes.QueryModuleAccountByNameRequest{ + Name: vaulttypes.MegavaultAccountName, + }, + ) + require.NoError(t, err) + require.NotNil(t, resp) + + moduleAccResp := authtypes.QueryModuleAccountByNameResponse{} + err = proto.UnmarshalText(resp.String(), &moduleAccResp) + require.NoError(t, err) +}