diff --git a/.github/workflows/publish-iscmagic.yml b/.github/workflows/publish-npm.yml similarity index 76% rename from .github/workflows/publish-iscmagic.yml rename to .github/workflows/publish-npm.yml index 577ed4b9a7..94fd91c2c3 100644 --- a/.github/workflows/publish-iscmagic.yml +++ b/.github/workflows/publish-npm.yml @@ -1,4 +1,4 @@ -name: Publish @iota/iscmagic +name: Publish @iota NPM packages on: workflow_call: @@ -6,6 +6,9 @@ on: version: required: true type: string + workingDirectory: + required: true + type: string secrets: NPM_TOKEN: required: true @@ -15,18 +18,18 @@ jobs: runs-on: ubuntu-latest defaults: run: - working-directory: ./packages/vm/core/evm/iscmagic + working-directory: ${{ inputs.workingDirectory }} steps: - uses: actions/checkout@v3 - - + - uses: actions/setup-node@v3 with: node-version: lts/* registry-url: 'https://registry.npmjs.org' scope: iota - - + - run: npm version ${{ inputs.version }} - - + - run: npm publish --access public env: NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index dfc7255c35..c0a1067c3a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,10 +115,16 @@ jobs: build-args: | BUILD_LD_FLAGS=-X=github.com/iotaledger/wasp/components/app.Version=${{ steps.tagger.outputs.tag }} - release-iscmagic: - uses: ./.github/workflows/publish-iscmagic.yml + + release-npm-packacges: needs: release-docker - with: - version: ${{ needs.release-docker.outputs.version }} - secrets: - NPM_TOKEN: ${{ secrets.NPM_TOKEN }} + runs-on: ubuntu-latest + strategy: + matrix: + workingDirectory: ['./packages/vm/core/evm/iscmagic', './tools/evm/iscutils'] + steps: + - name: Release NPM package + uses: ./.github/workflows/publish-npm.yml + with: + version: ${{ needs.release-docker.outputs.version }} + workingDirectory: ${{ matrix.workingDirectory }} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c2ad20a3ae..f6e4505829 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -ARG GOLANG_IMAGE_TAG=1.20-bullseye +ARG GOLANG_IMAGE_TAG=1.21-bullseye # Build stage FROM golang:${GOLANG_IMAGE_TAG} AS build diff --git a/Dockerfile.noncached b/Dockerfile.noncached index f25f884341..2f9a52d9fa 100644 --- a/Dockerfile.noncached +++ b/Dockerfile.noncached @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -ARG GOLANG_IMAGE_TAG=1.20-bullseye +ARG GOLANG_IMAGE_TAG=1.21-bullseye # Build stage FROM golang:${GOLANG_IMAGE_TAG} AS build diff --git a/contracts/wasm/testwasmlib/go/testwasmlibimpl/funcs.go b/contracts/wasm/testwasmlib/go/testwasmlibimpl/funcs.go index 8c8130c99b..4d260a1aec 100644 --- a/contracts/wasm/testwasmlib/go/testwasmlibimpl/funcs.go +++ b/contracts/wasm/testwasmlib/go/testwasmlibimpl/funcs.go @@ -793,3 +793,20 @@ func viewCheckEthEmptyAddressAndAgentID(ctx wasmlib.ScViewContext, f *CheckEthEm func viewCheckEthInvalidEmptyAddressFromString(_ wasmlib.ScViewContext, _ *CheckEthInvalidEmptyAddressFromStringContext) { _ = wasmtypes.AddressFromString("0x00") } + +func funcActivate(ctx wasmlib.ScFuncContext, f *ActivateContext) { + f.State.Active().SetValue(true) + deposit := ctx.Allowance().BaseTokens() + transfer := wasmlib.ScTransferFromBaseTokens(deposit) + ctx.TransferAllowed(ctx.AccountID(), transfer) + delay := f.Params.Seconds().Value() + testwasmlib.ScFuncs.Deactivate(ctx).Func.Delay(delay).Post() +} + +func funcDeactivate(_ wasmlib.ScFuncContext, f *DeactivateContext) { + f.State.Active().SetValue(false) +} + +func viewGetActive(_ wasmlib.ScViewContext, f *GetActiveContext) { + f.Results.Active().SetValue(f.State.Active().Value()) +} diff --git a/contracts/wasm/testwasmlib/rs/testwasmlibimpl/src/funcs.rs b/contracts/wasm/testwasmlib/rs/testwasmlibimpl/src/funcs.rs index dc0abad283..9d99124df5 100644 --- a/contracts/wasm/testwasmlib/rs/testwasmlibimpl/src/funcs.rs +++ b/contracts/wasm/testwasmlib/rs/testwasmlibimpl/src/funcs.rs @@ -1351,3 +1351,20 @@ pub fn view_check_eth_invalid_empty_address_from_string( ) { address_from_string("0x00"); } + +pub fn func_activate(ctx: &ScFuncContext, f: &ActivateContext) { + f.state.active().set_value(true); + let deposit = ctx.allowance().base_tokens(); + let transfer = wasmlib::ScTransfer::base_tokens(deposit); + ctx.transfer_allowed(&ctx.account_id(), &transfer); + let delay = f.params.seconds().value(); + testwasmlib::ScFuncs::deactivate(ctx).func.delay(delay).post(); +} + +pub fn func_deactivate(ctx: &ScFuncContext, f: &DeactivateContext) { + f.state.active().set_value(false); +} + +pub fn view_get_active(ctx: &ScViewContext, f: &GetActiveContext) { + f.results.active().set_value(f.state.active().value()); +} diff --git a/contracts/wasm/testwasmlib/schema.yaml b/contracts/wasm/testwasmlib/schema.yaml index cd336484b4..683656e856 100644 --- a/contracts/wasm/testwasmlib/schema.yaml +++ b/contracts/wasm/testwasmlib/schema.yaml @@ -58,6 +58,7 @@ typedefs: # ################################## state: + active: Bool # basic datatypes, using String arrayOfStringArray: StringArray[] arrayOfStringMap: StringMap[] @@ -75,6 +76,13 @@ state: # ################################## funcs: + activate: + params: + seconds: Uint32 + + deactivate: + access: self # only SC itself can invoke this function + stringMapOfStringArrayAppend: params: name: String @@ -208,6 +216,10 @@ funcs: # ################################## views: + getActive: + results: + active: Bool + stringMapOfStringArrayLength: params: name: String diff --git a/contracts/wasm/testwasmlib/test/testwasmlib_client_test.go b/contracts/wasm/testwasmlib/test/testwasmlib_client_test.go index 55b4f8d1d0..dbcabde756 100644 --- a/contracts/wasm/testwasmlib/test/testwasmlib_client_test.go +++ b/contracts/wasm/testwasmlib/test/testwasmlib_client_test.go @@ -154,6 +154,55 @@ func newClient(t testing.TB, svcClient wasmclient.IClientService, wallet *crypto return ctx } +func TestTimedDeactivation(t *testing.T) { + if !useDisposable && !useCluster { + t.SkipNow() + } + + var ctxCluster *wasmclient.WasmClientContext + if useCluster { + ctxCluster = setupClient(t) + } + + ctx := setupClientLib(t) + require.NoError(t, ctx.Err) + + active := getActive(t, ctx) + require.False(t, active) + + f := testwasmlib.ScFuncs.Activate(ctx) + f.Params.Seconds().SetValue(420) + f.Func.TransferBaseTokens(2_000_000).AllowanceBaseTokens(1_000_000).Post() + require.NoError(t, ctx.Err) + + ctx.WaitRequest() + require.NoError(t, ctx.Err) + + for i := 0; i < 100; i++ { + active = getActive(t, ctx) + seconds := 20 + fmt.Printf("TICK #%d: %v\n", i*seconds, active) + if !active { + break + } + factor := time.Duration(seconds) + if useCluster { + // time marches 10x faster + factor /= 10 + } + time.Sleep(factor * time.Second) + } + + _ = ctxCluster +} + +func getActive(t *testing.T, ctx *wasmclient.WasmClientContext) bool { + a := testwasmlib.ScFuncs.GetActive(ctx) + a.Func.Call() + require.NoError(t, ctx.Err) + return a.Results.Active().Value() +} + func TestClientAccountBalance(t *testing.T) { ctx := setupClient(t) wallet := ctx.CurrentKeyPair() diff --git a/contracts/wasm/testwasmlib/ts/testwasmlibimpl/funcs.ts b/contracts/wasm/testwasmlib/ts/testwasmlibimpl/funcs.ts index d771610b18..4757f8babb 100644 --- a/contracts/wasm/testwasmlib/ts/testwasmlibimpl/funcs.ts +++ b/contracts/wasm/testwasmlib/ts/testwasmlibimpl/funcs.ts @@ -780,3 +780,20 @@ export function viewCheckEthEmptyAddressAndAgentID(ctx: wasmlib.ScViewContext, f export function viewCheckEthInvalidEmptyAddressFromString(ctx: wasmlib.ScViewContext, f: sc.CheckEthInvalidEmptyAddressFromStringContext): void { wasmtypes.addressFromString("0x00"); } + +export function funcActivate(ctx: wasmlib.ScFuncContext, f: sc.ActivateContext): void { + f.state.active().setValue(true); + const deposit = ctx.allowance().baseTokens(); + const transfer = wasmlib.ScTransfer.baseTokens(deposit); + ctx.transferAllowed(ctx.accountID(), transfer); + const delay = f.params.seconds().value(); + sc.ScFuncs.deactivate(ctx).func.delay(delay).post(); +} + +export function funcDeactivate(ctx: wasmlib.ScFuncContext, f: sc.DeactivateContext): void { + f.state.active().setValue(false); +} + +export function viewGetActive(ctx: wasmlib.ScViewContext, f: sc.GetActiveContext): void { + f.results.active().setValue(f.state.active().value()); +} diff --git a/documentation/tutorial-examples/test/solotutorial_bg.wasm b/documentation/tutorial-examples/test/solotutorial_bg.wasm index 2e33d13bee..aeac80beda 100644 Binary files a/documentation/tutorial-examples/test/solotutorial_bg.wasm and b/documentation/tutorial-examples/test/solotutorial_bg.wasm differ diff --git a/packages/authentication/auth_context.go b/packages/authentication/auth_context.go new file mode 100644 index 0000000000..679a42d34c --- /dev/null +++ b/packages/authentication/auth_context.go @@ -0,0 +1,19 @@ +package authentication + +import "github.com/labstack/echo/v4" + +type AuthContext struct { + echo.Context + + scheme string + claims *WaspClaims + name string +} + +func (a *AuthContext) Name() string { + return a.name +} + +func (a *AuthContext) Scheme() string { + return a.scheme +} diff --git a/packages/authentication/basic_auth.go b/packages/authentication/basic_auth.go deleted file mode 100644 index 12338e6bd5..0000000000 --- a/packages/authentication/basic_auth.go +++ /dev/null @@ -1,35 +0,0 @@ -package authentication - -import ( - "fmt" - - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - - "github.com/iotaledger/hive.go/web/basicauth" - "github.com/iotaledger/wasp/packages/users" -) - -func AddBasicAuth(webAPI WebAPI, userManager *users.UserManager) { - webAPI.Use(middleware.BasicAuth(func(username, password string, c echo.Context) (bool, error) { - authContext := c.Get("auth").(*AuthContext) - - user, err := userManager.User(username) - if err != nil { - return false, err - } - - valid, err := basicauth.VerifyPassword([]byte(password), user.PasswordSalt, user.PasswordHash) - if err != nil { - return false, fmt.Errorf("failed to verify password: %w", err) - } - - if !valid { - return false, nil - } - - authContext.name = username - authContext.isAuthenticated = true - return true, nil - })) -} diff --git a/packages/authentication/context.go b/packages/authentication/context.go deleted file mode 100644 index 62496211cc..0000000000 --- a/packages/authentication/context.go +++ /dev/null @@ -1,44 +0,0 @@ -package authentication - -import ( - "github.com/labstack/echo/v4" -) - -type ( - ClaimValidator func(claims *WaspClaims) bool - AccessValidator func(validator ClaimValidator) bool -) - -type AuthContext struct { - echo.Context - - scheme string - isAuthenticated bool - claims *WaspClaims - name string -} - -func (a *AuthContext) Name() string { - return a.name -} - -func (a *AuthContext) IsAuthenticated() bool { - return a.isAuthenticated -} - -func (a *AuthContext) Scheme() string { - return a.scheme -} - -func (a *AuthContext) IsAllowedTo(validator ClaimValidator) bool { - if !a.isAuthenticated { - return false - } - - if a.scheme == AuthJWT { - return validator(a.claims) - } - - // IP Whitelist and Basic Auth will always give access to everything! - return true -} diff --git a/packages/authentication/ip_whitelist.go b/packages/authentication/ip_whitelist.go deleted file mode 100644 index e58944d44e..0000000000 --- a/packages/authentication/ip_whitelist.go +++ /dev/null @@ -1,53 +0,0 @@ -package authentication - -import ( - "net" - "strings" - - "github.com/labstack/echo/v4" -) - -func AddIPWhiteListAuth(webAPI WebAPI, config IPWhiteListAuthConfiguration) { - ipWhiteList := createIPWhiteList(config) - webAPI.Use(protected(ipWhiteList)) -} - -func createIPWhiteList(config IPWhiteListAuthConfiguration) []net.IP { - r := make([]net.IP, 0) - for _, ip := range config.Whitelist { - r = append(r, net.ParseIP(ip)) - } - return r -} - -func isAllowed(ip net.IP, whitelist []net.IP) bool { - if ip.IsLoopback() { - return true - } - for _, whitelistedIP := range whitelist { - if ip.Equal(whitelistedIP) { - return true - } - } - return false -} - -func protected(whitelist []net.IP) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - authContext := c.Get("auth").(*AuthContext) - - parts := strings.Split(c.Request().RemoteAddr, ":") - if len(parts) == 2 { - ip := net.ParseIP(parts[0]) - if ip != nil && isAllowed(ip, whitelist) { - authContext.isAuthenticated = true - return next(c) - } - } - - c.Logger().Infof("Blocking request from %s: %s %s", c.Request().RemoteAddr, c.Request().Method, c.Request().RequestURI) - return echo.ErrUnauthorized - } - } -} diff --git a/packages/authentication/jwt_auth.go b/packages/authentication/jwt_auth.go index c280125ca6..4f02baa94b 100644 --- a/packages/authentication/jwt_auth.go +++ b/packages/authentication/jwt_auth.go @@ -4,16 +4,12 @@ import ( "crypto/subtle" "fmt" "net/http" - "strings" "time" "github.com/golang-jwt/jwt/v5" - echojwt "github.com/labstack/echo-jwt/v4" "github.com/labstack/echo/v4" - "github.com/iotaledger/wasp/packages/authentication/shared" "github.com/iotaledger/wasp/packages/authentication/shared/permissions" - "github.com/iotaledger/wasp/packages/users" ) // Errors @@ -32,8 +28,6 @@ type JWTAuth struct { secret []byte } -type MiddlewareValidator = func(c echo.Context, authContext *AuthContext) bool - func NewJWTAuth(duration time.Duration, nodeID string, secret []byte) *JWTAuth { return &JWTAuth{ duration: duration, @@ -42,6 +36,32 @@ func NewJWTAuth(duration time.Duration, nodeID string, secret []byte) *JWTAuth { } } +func (j *JWTAuth) IssueJWT(username string, claims *WaspClaims) (string, error) { + now := time.Now() + + // Set claims + registeredClaims := jwt.RegisteredClaims{ + Subject: username, + Issuer: j.nodeID, + Audience: jwt.ClaimStrings{j.nodeID}, + ID: fmt.Sprintf("%d", now.Unix()), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + } + + if j.duration > 0 { + registeredClaims.ExpiresAt = jwt.NewNumericDate(now.Add(j.duration)) + } + + claims.RegisteredClaims = registeredClaims + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Generate encoded token and send it as response. + return token.SignedString(j.secret) +} + type WaspClaims struct { jwt.RegisteredClaims Permissions map[string]struct{} `json:"permissions"` @@ -78,117 +98,3 @@ func (c *WaspClaims) compare(field, expected string) bool { func (c *WaspClaims) VerifySubject(expected string) bool { return c.compare(c.Subject, expected) } - -func (j *JWTAuth) IssueJWT(username string, authClaims *WaspClaims) (string, error) { - now := time.Now() - - // Set claims - registeredClaims := jwt.RegisteredClaims{ - Subject: username, - Issuer: j.nodeID, - Audience: jwt.ClaimStrings{j.nodeID}, - ID: fmt.Sprintf("%d", now.Unix()), - IssuedAt: jwt.NewNumericDate(now), - NotBefore: jwt.NewNumericDate(now), - } - - if j.duration > 0 { - registeredClaims.ExpiresAt = jwt.NewNumericDate(now.Add(j.duration)) - } - - authClaims.RegisteredClaims = registeredClaims - - // Create token - token := jwt.NewWithClaims(jwt.SigningMethodHS256, authClaims) - - // Generate encoded token and send it as response. - return token.SignedString(j.secret) -} - -var DefaultJWTDuration time.Duration - -func AddJWTAuth(config JWTAuthConfiguration, privateKey []byte, userManager *users.UserManager, claimValidator ClaimValidator) (*JWTAuth, func() echo.MiddlewareFunc) { - duration := config.Duration - - // If durationHours is 0, we set 24h as the default duration - if duration == 0 { - duration = DefaultJWTDuration - } - - // FIXME: replace "wasp" as nodeID - jwtAuth := NewJWTAuth(duration, "wasp", privateKey) - - authMiddleware := func() echo.MiddlewareFunc { - return echojwt.WithConfig(echojwt.Config{ - ContextKey: JWTContextKey, - NewClaimsFunc: func(c echo.Context) jwt.Claims { - return &WaspClaims{} - }, - Skipper: func(c echo.Context) bool { - path := c.Request().URL.Path - if path == "/" || - strings.HasSuffix(path, shared.AuthRoute()) || - strings.HasSuffix(path, shared.AuthInfoRoute()) || - strings.HasPrefix(path, "/doc") { - return true - } - - return false - }, - SigningKey: jwtAuth.secret, - TokenLookup: "header:Authorization:Bearer ,cookie:jwt", - ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { - keyFunc := func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - - return jwtAuth.secret, nil - } - - token, err := jwt.ParseWithClaims( - auth, - &WaspClaims{}, - keyFunc, - jwt.WithValidMethods([]string{"HS256"}), - ) - if err != nil { - return nil, err - } - if !token.Valid { - return nil, fmt.Errorf("invalid token") - } - - claims, ok := token.Claims.(*WaspClaims) - if !ok { - return nil, fmt.Errorf("wrong JWT claim type") - } - - audience, err := claims.GetAudience() - if err != nil { - return nil, err - } - b, err := audience.MarshalJSON() - if err != nil { - return nil, err - } - if subtle.ConstantTimeCompare(b, []byte(fmt.Sprintf("[%q]", jwtAuth.nodeID))) == 0 { - return nil, fmt.Errorf("not in audience") - } - - userMap := userManager.Users() - if _, ok := userMap[claims.Subject]; !ok { - return nil, fmt.Errorf("invalid subject") - } - - authContext := c.Get("auth").(*AuthContext) - authContext.isAuthenticated = true - authContext.claims = claims - - return token, nil - }, - }) - } - - return jwtAuth, authMiddleware -} diff --git a/packages/authentication/jwt_auth_test.go b/packages/authentication/jwt_auth_test.go index 35e6070742..80b3d22614 100644 --- a/packages/authentication/jwt_auth_test.go +++ b/packages/authentication/jwt_auth_test.go @@ -16,7 +16,7 @@ import ( "github.com/iotaledger/wasp/packages/users" ) -func TestAddJWTAuth(t *testing.T) { +func TestGetJWTAuthMiddleware(t *testing.T) { t.Run("normal", func(t *testing.T) { e := echo.New() e.GET("/test-route", func(c echo.Context) error { @@ -32,11 +32,10 @@ func TestAddJWTAuth(t *testing.T) { Name: "wasp", }) - _, middleware := authentication.AddJWTAuth( + _, middleware := authentication.GetJWTAuthMiddleware( authentication.JWTAuthConfiguration{}, []byte("abc"), userManager, - nil, // remove claim validator ) e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -45,7 +44,7 @@ func TestAddJWTAuth(t *testing.T) { return next(c) } }) - e.Use(middleware()) + e.Use(middleware) req := httptest.NewRequest(http.MethodGet, "/test-route", http.NoBody) req.Header.Set(echo.HeaderAuthorization, "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ3YXNwIiwic3ViIjoid2FzcCIsImF1ZCI6WyJ3YXNwIl0sImV4cCI6NDg0NTUwNjQ5MiwibmJmIjoxNjg5ODYxNDM2LCJpYXQiOjE2ODk4NjE0MzYsImp0aSI6IjE2ODk4NjE0MzYiLCJwZXJtaXNzaW9ucyI6eyJ3cml0ZSI6e319fQ.VP--725H3xO2Spz6L9twB6Tsm37a26IXVU87cSqRoOM") @@ -73,13 +72,12 @@ func TestAddJWTAuth(t *testing.T) { }) } - _, middleware := authentication.AddJWTAuth( + _, middleware := authentication.GetJWTAuthMiddleware( authentication.JWTAuthConfiguration{}, []byte(""), &users.UserManager{}, - nil, // remove claim validator ) - e.Use(middleware()) + e.Use(middleware) for _, path := range skipPaths { req := httptest.NewRequest(http.MethodGet, path, http.NoBody) @@ -119,11 +117,10 @@ func TestJWTAuthIssueAndVerify(t *testing.T) { Name: username, }) - _, middleware := authentication.AddJWTAuth( + _, middleware := authentication.GetJWTAuthMiddleware( authentication.JWTAuthConfiguration{Duration: duration}, privateKey, userManager, - nil, // remove claim validator ) e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -132,7 +129,7 @@ func TestJWTAuthIssueAndVerify(t *testing.T) { return next(c) } }) - e.Use(middleware()) + e.Use(middleware) req := httptest.NewRequest(http.MethodGet, "/test-route", http.NoBody) req.Header.Set(echo.HeaderAuthorization, fmt.Sprintf("Bearer %s", jwtString)) diff --git a/packages/authentication/jwt_handler.go b/packages/authentication/jwt_handler.go deleted file mode 100644 index cb7618f2c0..0000000000 --- a/packages/authentication/jwt_handler.go +++ /dev/null @@ -1,106 +0,0 @@ -package authentication - -import ( - "errors" - "fmt" - "net/http" - "time" - - "github.com/labstack/echo/v4" - - "github.com/iotaledger/hive.go/web/basicauth" - "github.com/iotaledger/wasp/packages/authentication/shared" - "github.com/iotaledger/wasp/packages/users" -) - -const headerXForwardedPrefix = "X-Forwarded-Prefix" - -type AuthHandler struct { - Jwt *JWTAuth - UserManager *users.UserManager -} - -func (a *AuthHandler) validateLogin(user *users.User, password string) bool { - valid, err := basicauth.VerifyPassword([]byte(password), user.PasswordSalt, user.PasswordHash) - if err != nil { - return false - } - - return valid -} - -func (a *AuthHandler) stageAuthRequest(c echo.Context) (string, error) { - request := &shared.LoginRequest{} - - if err := c.Bind(request); err != nil { - return "", errors.New("invalid form data") - } - - user, err := a.UserManager.User(request.Username) - if err != nil { - return "", errors.New("invalid credentials") - } - - if !a.validateLogin(user, request.Password) { - return "", errors.New("invalid credentials") - } - - claims := &WaspClaims{ - Permissions: user.Permissions, - } - - token, err := a.Jwt.IssueJWT(request.Username, claims) - if err != nil { - return "", errors.New("unable to login") - } - - return token, nil -} - -func (a *AuthHandler) handleJSONAuthRequest(c echo.Context, token string, errorResult error) error { - if errorResult != nil { - return c.JSON(http.StatusUnauthorized, shared.LoginResponse{Error: errorResult}) - } - - return c.JSON(http.StatusOK, shared.LoginResponse{JWT: token}) -} - -func (a *AuthHandler) redirect(c echo.Context, uri string) error { - return c.Redirect(http.StatusFound, c.Request().Header.Get(headerXForwardedPrefix)+uri) -} - -func (a *AuthHandler) handleFormAuthRequest(c echo.Context, token string, errorResult error) error { - if errorResult != nil { - // TODO: Add sessions to get rid of the query parameter? - return a.redirect(c, fmt.Sprintf("%s?error=%s", shared.AuthRoute(), errorResult)) - } - - cookie := http.Cookie{ - Name: "jwt", - Value: token, - HttpOnly: true, // JWT Token will be stored in a http only cookie, this is important to mitigate XSS/XSRF attacks - Expires: time.Now().Add(a.Jwt.duration), - Path: "/", - SameSite: http.SameSiteStrictMode, - } - - c.SetCookie(&cookie) - - return a.redirect(c, shared.AuthRouteSuccess()) -} - -func (a *AuthHandler) CrossAPIAuthHandler(c echo.Context) error { - token, errorResult := a.stageAuthRequest(c) - - contentType := c.Request().Header.Get(echo.HeaderContentType) - - if contentType == echo.MIMEApplicationJSON { - return a.handleJSONAuthRequest(c, token, errorResult) - } - - if contentType == echo.MIMEApplicationForm { - return a.handleFormAuthRequest(c, token, errorResult) - } - - return errors.New("invalid login request") -} diff --git a/packages/authentication/jwt_login.go b/packages/authentication/jwt_login.go new file mode 100644 index 0000000000..700de01cf7 --- /dev/null +++ b/packages/authentication/jwt_login.go @@ -0,0 +1,67 @@ +package authentication + +import ( + "errors" + "fmt" + "net/http" + + "github.com/labstack/echo/v4" + + "github.com/iotaledger/hive.go/web/basicauth" + "github.com/iotaledger/wasp/packages/authentication/shared" + "github.com/iotaledger/wasp/packages/users" +) + +type AuthHandler struct { + Jwt *JWTAuth + UserManager *users.UserManager +} + +func (a *AuthHandler) JWTLoginHandler(c echo.Context) error { + if c.Request().Header.Get(echo.HeaderContentType) != echo.MIMEApplicationJSON { + return errors.New("invalid login request") + } + + req, user, err := a.parseAuthRequest(c) + if err != nil { + return c.JSON(http.StatusUnauthorized, shared.LoginResponse{Error: err}) + } + + claims := &WaspClaims{ + Permissions: user.Permissions, + } + token, err := a.Jwt.IssueJWT(req.Username, claims) + if err != nil { + return c.JSON(http.StatusUnauthorized, shared.LoginResponse{Error: fmt.Errorf("unable to login")}) + } + + return c.JSON(http.StatusOK, shared.LoginResponse{JWT: token}) +} + +func (a *AuthHandler) parseAuthRequest(c echo.Context) (*shared.LoginRequest, *users.User, error) { + request := &shared.LoginRequest{} + + if err := c.Bind(request); err != nil { + return nil, nil, fmt.Errorf("invalid form data") + } + + user, err := a.UserManager.User(request.Username) + if err != nil { + return nil, nil, fmt.Errorf("invalid credentials") + } + + if !validatePassword(user, request.Password) { + return nil, nil, fmt.Errorf("invalid credentials") + } + + return request, user, nil +} + +func validatePassword(user *users.User, password string) bool { + valid, err := basicauth.VerifyPassword([]byte(password), user.PasswordSalt, user.PasswordHash) + if err != nil { + return false + } + + return valid +} diff --git a/packages/authentication/routes.go b/packages/authentication/routes.go new file mode 100644 index 0000000000..d509f340e2 --- /dev/null +++ b/packages/authentication/routes.go @@ -0,0 +1,108 @@ +package authentication + +import ( + "fmt" + "net/http" + "time" + + "github.com/labstack/echo/v4" + "github.com/pangpanglabs/echoswagger/v2" + + "github.com/iotaledger/wasp/packages/authentication/shared" + "github.com/iotaledger/wasp/packages/registry" + "github.com/iotaledger/wasp/packages/users" + "github.com/iotaledger/wasp/packages/webapi/interfaces" +) + +const ( + AuthNone = "none" + AuthJWT = "jwt" +) + +type JWTAuthConfiguration struct { + Duration time.Duration `default:"24h" usage:"jwt token lifetime"` +} + +type AuthConfiguration struct { + Scheme string `default:"ip" usage:"selects which authentication to choose"` + + JWTConfig JWTAuthConfiguration `name:"jwt" usage:"defines the jwt configuration"` +} + +type WebAPI interface { + GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route + Use(middleware ...echo.MiddlewareFunc) +} + +func AddAuthentication( + apiRoot echoswagger.ApiRoot, + userManager *users.UserManager, + nodeIdentityProvider registry.NodeIdentityProvider, + authConfig AuthConfiguration, + mocker interfaces.Mocker, +) echo.MiddlewareFunc { + echoRoot := apiRoot.Echo() + authGroup := apiRoot.Group("auth", "") + + // initialize AuthContext obj as var in echo.Context + echoRoot.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("auth", &AuthContext{ + scheme: authConfig.Scheme, + }) + + return next(c) + } + }) + + // set AuthInfo route + authGroup.GET(shared.AuthInfoRoute(), authInfoHandler(authConfig)). + AddResponse(http.StatusOK, "Login was successful", mocker.Get(shared.AuthInfoModel{}), nil). + SetOperationId("authInfo"). + SetSummary("Get information about the current authentication mode") + + // set Auth route + var middleware echo.MiddlewareFunc + var handler echo.HandlerFunc + switch authConfig.Scheme { + case AuthJWT: + var jwtAuth *JWTAuth + privateKey := nodeIdentityProvider.NodeIdentity().GetPrivateKey().AsBytes() + + // The primary claim is the one mandatory claim that gives access to api/webapi/alike + jwtAuth, middleware = GetJWTAuthMiddleware(authConfig.JWTConfig, privateKey, userManager) + authHandler := &AuthHandler{Jwt: jwtAuth, UserManager: userManager} + handler = authHandler.JWTLoginHandler + + case AuthNone: + middleware = GetNoneAuthMiddleware() + handler = nil + + default: + panic(fmt.Sprintf("Unknown auth scheme %s", authConfig.Scheme)) + } + + authGroup.POST(shared.AuthRoute(), handler). + AddParamBody(mocker.Get(shared.LoginRequest{}), "", "The login request", true). + AddResponse(http.StatusUnauthorized, "Unauthorized (Wrong permissions, missing token)", nil, nil). + AddResponse(http.StatusMethodNotAllowed, "auth type: none", nil, nil). + AddResponse(http.StatusOK, "Login was successful", mocker.Get(shared.LoginResponse{}), nil). + SetOperationId("authenticate"). + SetSummary("Authenticate towards the node") + return middleware +} + +func authInfoHandler(authConfig AuthConfiguration) func(c echo.Context) error { + return func(c echo.Context) error { + model := shared.AuthInfoModel{ + Scheme: authConfig.Scheme, + } + + if model.Scheme == AuthJWT { + model.AuthURL = shared.AuthRoute() + } + + return c.JSON(http.StatusOK, model) + } +} diff --git a/packages/authentication/shared/routes.go b/packages/authentication/shared/routes.go index 25dffdbb62..f71811051d 100644 --- a/packages/authentication/shared/routes.go +++ b/packages/authentication/shared/routes.go @@ -4,10 +4,6 @@ func AuthRoute() string { return "/auth" } -func AuthRouteSuccess() string { - return "/auth/success" -} - func AuthInfoRoute() string { return "/auth/info" } diff --git a/packages/authentication/status.go b/packages/authentication/status.go deleted file mode 100644 index 046269882a..0000000000 --- a/packages/authentication/status.go +++ /dev/null @@ -1,33 +0,0 @@ -package authentication - -import ( - "net/http" - - "github.com/labstack/echo/v4" - - "github.com/iotaledger/wasp/packages/authentication/shared" -) - -type StatusWebAPIModel struct { - config AuthConfiguration -} - -func (a *StatusWebAPIModel) handleAuthenticationStatus(c echo.Context) error { - model := shared.AuthInfoModel{ - Scheme: a.config.Scheme, - } - - if model.Scheme == AuthJWT { - model.AuthURL = shared.AuthRoute() - } - - return c.JSON(http.StatusOK, model) -} - -func addAuthenticationStatus(webAPI WebAPI, config AuthConfiguration) { - c := &StatusWebAPIModel{ - config: config, - } - - webAPI.GET(shared.AuthInfoRoute(), c.handleAuthenticationStatus) -} diff --git a/packages/authentication/strategy.go b/packages/authentication/strategy.go deleted file mode 100644 index 0f61e0fea9..0000000000 --- a/packages/authentication/strategy.go +++ /dev/null @@ -1,171 +0,0 @@ -package authentication - -import ( - "fmt" - "net/http" - "time" - - "github.com/labstack/echo/v4" - "github.com/pangpanglabs/echoswagger/v2" - - "github.com/iotaledger/wasp/packages/authentication/shared" - "github.com/iotaledger/wasp/packages/registry" - "github.com/iotaledger/wasp/packages/users" -) - -const ( - AuthJWT = "jwt" - AuthBasic = "basic" - AuthIPWhitelist = "ip" - AuthNone = "none" -) - -type JWTAuthConfiguration struct { - Duration time.Duration `default:"24h" usage:"jwt token lifetime"` -} - -type BasicAuthConfiguration struct { - Username string `default:"wasp" usage:"the username which grants access to the service"` -} - -type IPWhiteListAuthConfiguration struct { - Whitelist []string `default:"0.0.0.0" usage:"a list of ips that are allowed to access the service"` -} - -type AuthConfiguration struct { - Scheme string `default:"ip" usage:"selects which authentication to choose"` - - JWTConfig JWTAuthConfiguration `name:"jwt" usage:"defines the jwt configuration"` - BasicAuthConfig BasicAuthConfiguration `name:"basic" usage:"defines the basic auth configuration"` - IPWhitelistConfig IPWhiteListAuthConfiguration `name:"ip" usage:"defines the whitelist configuration"` -} - -type WebAPI interface { - GET(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - POST(path string, h echo.HandlerFunc, m ...echo.MiddlewareFunc) *echo.Route - Use(middleware ...echo.MiddlewareFunc) -} - -func AddNoneAuth(webAPI WebAPI) { - // Adds a middleware to set the authContext to authenticated. - // All routes will be open to everyone, so use it in private environments only. - // Handle with care! - noneFunc := func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - authContext := c.Get("auth").(*AuthContext) - - authContext.isAuthenticated = true - - return next(c) - } - } - - webAPI.Use(noneFunc) -} - -func AddV1Authentication( - webAPI WebAPI, - userManager *users.UserManager, - nodeIdentityProvider registry.NodeIdentityProvider, - authConfig AuthConfiguration, - claimValidator ClaimValidator, -) { - addAuthContext(webAPI, authConfig) - - switch authConfig.Scheme { - case AuthBasic: - AddBasicAuth(webAPI, userManager) - case AuthJWT: - nodeIdentity := nodeIdentityProvider.NodeIdentity() - privateKey := nodeIdentity.GetPrivateKey().AsBytes() - - // The primary claim is the one mandatory claim that gives access to api/webapi/alike - jwtAuth, authMiddleware := AddJWTAuth(authConfig.JWTConfig, privateKey, userManager, claimValidator) - - authHandler := &AuthHandler{Jwt: jwtAuth, UserManager: userManager} - webAPI.POST(shared.AuthRoute(), authHandler.CrossAPIAuthHandler) - webAPI.Use(authMiddleware()) - - case AuthIPWhitelist: - AddIPWhiteListAuth(webAPI, authConfig.IPWhitelistConfig) - - case AuthNone: - AddNoneAuth(webAPI) - - default: - panic(fmt.Sprintf("Unknown auth scheme %s", authConfig.Scheme)) - } - - addAuthenticationStatus(webAPI, authConfig) -} - -// TODO: After deprecating V1 we can slim down this whole strategy handler. -// It is currently needed as the current authentication scheme does not support echoSwagger, -// which leaves authentication out of the client code generator. -// After v1 gets removed: -// * Get rid off basic/ip auth and only keeping 'none' and 'JWT' -// * Properly document the routes with echoSwagger -// * Keep only one AddAuthentication method - -func AddV2Authentication(apiRoot echoswagger.ApiRoot, - userManager *users.UserManager, - nodeIdentityProvider registry.NodeIdentityProvider, - authConfig AuthConfiguration, - claimValidator ClaimValidator, -) func() echo.MiddlewareFunc { - echoRoot := apiRoot.Echo() - authGroup := apiRoot.Group("auth", "") - - addAuthContext(echoRoot, authConfig) - - c := &StatusWebAPIModel{ - config: authConfig, - } - - authGroup.GET(shared.AuthInfoRoute(), c.handleAuthenticationStatus). - AddResponse(http.StatusOK, "Login was successful", shared.AuthInfoModel{}, nil). - SetOperationId("authInfo"). - SetSummary("Get information about the current authentication mode") - - switch authConfig.Scheme { - case AuthJWT: - nodeIdentity := nodeIdentityProvider.NodeIdentity() - privateKey := nodeIdentity.GetPrivateKey().AsBytes() - - // The primary claim is the one mandatory claim that gives access to api/webapi/alike - jwtAuth, jwtMiddleware := AddJWTAuth(authConfig.JWTConfig, privateKey, userManager, claimValidator) - - authHandler := &AuthHandler{Jwt: jwtAuth, UserManager: userManager} - authGroup.POST(shared.AuthRoute(), authHandler.CrossAPIAuthHandler). - AddParamBody(shared.LoginRequest{}, "", "The login request", true). - AddResponse(http.StatusUnauthorized, "Unauthorized (Wrong permissions, missing token)", nil, nil). - AddResponse(http.StatusOK, "Login was successful", shared.LoginResponse{}, nil). - SetOperationId("authenticate"). - SetSummary("Authenticate towards the node") - - return jwtMiddleware - - case AuthNone: - AddNoneAuth(echoRoot) - authGroup.POST(shared.AuthRoute(), nil). - AddResponse(http.StatusMethodNotAllowed, "auth type: none", nil, nil) - return nil - - default: - panic(fmt.Sprintf("Unknown auth scheme %s", authConfig.Scheme)) - } -} - -func addAuthContext(webAPI WebAPI, config AuthConfiguration) { - webAPI.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - cc := &AuthContext{ - scheme: config.Scheme, - } - - c.Set("auth", cc) - - return next(c) - } - }) -} diff --git a/packages/authentication/validate_middleware.go b/packages/authentication/validate_middleware.go new file mode 100644 index 0000000000..47ef38be51 --- /dev/null +++ b/packages/authentication/validate_middleware.go @@ -0,0 +1,114 @@ +package authentication + +import ( + "crypto/subtle" + "fmt" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + echojwt "github.com/labstack/echo-jwt/v4" + "github.com/labstack/echo/v4" + + "github.com/iotaledger/wasp/packages/authentication/shared" + "github.com/iotaledger/wasp/packages/users" +) + +var DefaultJWTDuration time.Duration + +func GetJWTAuthMiddleware( + config JWTAuthConfiguration, + privateKey []byte, + userManager *users.UserManager, +) (*JWTAuth, echo.MiddlewareFunc) { + duration := config.Duration + // If durationHours is 0, we set 24h as the default duration + if duration == 0 { + duration = DefaultJWTDuration + } + + // FIXME: replace "wasp" as nodeID + jwtAuth := NewJWTAuth(duration, "wasp", privateKey) + + authMiddleware := echojwt.WithConfig(echojwt.Config{ + ContextKey: JWTContextKey, + NewClaimsFunc: func(c echo.Context) jwt.Claims { + return &WaspClaims{} + }, + Skipper: func(c echo.Context) bool { + path := c.Request().URL.Path + if path == "/" || + strings.HasSuffix(path, shared.AuthRoute()) || + strings.HasSuffix(path, shared.AuthInfoRoute()) || + strings.HasPrefix(path, "/doc") { + return true + } + + return false + }, + SigningKey: jwtAuth.secret, + TokenLookup: "header:Authorization:Bearer ,cookie:jwt", + ParseTokenFunc: func(c echo.Context, auth string) (interface{}, error) { + keyFunc := func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + + return jwtAuth.secret, nil + } + + token, err := jwt.ParseWithClaims( + auth, + &WaspClaims{}, + keyFunc, + jwt.WithValidMethods([]string{"HS256"}), + ) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, fmt.Errorf("invalid token") + } + + claims, ok := token.Claims.(*WaspClaims) + if !ok { + return nil, fmt.Errorf("wrong JWT claim type") + } + + audience, err := claims.GetAudience() + if err != nil { + return nil, err + } + b, err := audience.MarshalJSON() + if err != nil { + return nil, err + } + if subtle.ConstantTimeCompare(b, []byte(fmt.Sprintf("[%q]", jwtAuth.nodeID))) == 0 { + return nil, fmt.Errorf("not in audience") + } + + userMap := userManager.Users() + if _, ok := userMap[claims.Subject]; !ok { + return nil, fmt.Errorf("invalid subject") + } + + authContext := c.Get("auth").(*AuthContext) + authContext.claims = claims + + return token, nil + }, + }) + + return jwtAuth, authMiddleware +} + +func GetNoneAuthMiddleware() echo.MiddlewareFunc { + // Adds a middleware to set the authContext to authenticated. + // All routes will be open to everyone, so use it in private environments only. + // Handle with care! + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + return next(c) + } + } +} diff --git a/packages/authentication/validate_permissions.go b/packages/authentication/validate_permissions.go index 2c5364b1fa..88c8fe4d3b 100644 --- a/packages/authentication/validate_permissions.go +++ b/packages/authentication/validate_permissions.go @@ -28,10 +28,6 @@ func ValidatePermissions(permissions []string) func(next echo.HandlerFunc) echo. return next(e) } - if !authContext.IsAuthenticated() { - return e.JSON(http.StatusUnauthorized, ValidationError{Error: "Invalid token"}) - } - for _, permission := range permissions { if !authContext.claims.HasPermission(permission) { return e.JSON(http.StatusUnauthorized, ValidationError{MissingPermission: permission, Error: "Missing permission"}) diff --git a/packages/chain/mempool/mempool.go b/packages/chain/mempool/mempool.go index 1ddbfb00aa..1885e4f7b3 100644 --- a/packages/chain/mempool/mempool.go +++ b/packages/chain/mempool/mempool.go @@ -48,7 +48,6 @@ import ( "time" "github.com/samber/lo" - "golang.org/x/exp/slices" "github.com/iotaledger/hive.go/logger" consGR "github.com/iotaledger/wasp/packages/chain/cons/cons_gr" @@ -129,7 +128,7 @@ type mempoolImpl struct { tangleTime time.Time timePool TimePool onLedgerPool RequestPool[isc.OnLedgerRequest] - offLedgerPool RequestPool[isc.OffLedgerRequest] + offLedgerPool *TypedPoolByNonce[isc.OffLedgerRequest] distSync gpa.GPA chainHeadAO *isc.AliasOutputWithID chainHeadState state.State @@ -214,7 +213,7 @@ func New( tangleTime: time.Time{}, timePool: NewTimePool(metrics.SetTimePoolSize, log.Named("TIM")), onLedgerPool: NewTypedPool[isc.OnLedgerRequest](waitReq, metrics.SetOnLedgerPoolSize, metrics.SetOnLedgerReqTime, log.Named("ONL")), - offLedgerPool: NewTypedPool[isc.OffLedgerRequest](waitReq, metrics.SetOffLedgerPoolSize, metrics.SetOffLedgerReqTime, log.Named("OFF")), + offLedgerPool: NewTypedPoolByNonce[isc.OffLedgerRequest](waitReq, metrics.SetOffLedgerPoolSize, metrics.SetOffLedgerReqTime, log.Named("OFF")), chainHeadAO: nil, serverNodesUpdatedPipe: pipe.NewInfinitePipe[*reqServerNodesUpdated](), serverNodes: []*cryptolib.PublicKey{}, @@ -480,11 +479,11 @@ func (mpi *mempoolImpl) shouldAddOffledgerRequest(req isc.OffLedgerRequest) erro return fmt.Errorf("bad nonce, expected: %d", accountNonce) } - governanceState := governance.NewStateAccess(mpi.chainHeadState) // check user has on-chain balance accountsState := accounts.NewStateAccess(mpi.chainHeadState) if !accountsState.AccountExists(req.SenderAccount()) { // make an exception for gov calls (sender is chan owner and target is gov contract) + governanceState := governance.NewStateAccess(mpi.chainHeadState) chainOwner := governanceState.ChainOwnerID() isGovRequest := req.SenderAccount().Equals(chainOwner) && req.CallTarget().Contract == governance.Contract.Hname() if !isGovRequest { @@ -530,17 +529,12 @@ func (mpi *mempoolImpl) handleConsensusProposal(recv *reqConsensusProposal) { mpi.handleConsensusProposalForChainHead(recv) } -type reqRefNonce struct { - ref *isc.RequestRef - nonce uint64 -} - func (mpi *mempoolImpl) refsToPropose() []*isc.RequestRef { // // The case for matching ChainHeadAO and request BaseAO reqRefs := []*isc.RequestRef{} if !mpi.tangleTime.IsZero() { // Wait for tangle-time to process the on ledger requests. - mpi.onLedgerPool.Filter(func(request isc.OnLedgerRequest, ts time.Time) bool { + mpi.onLedgerPool.Filter(func(request isc.OnLedgerRequest, _ time.Time) bool { if isc.RequestIsExpired(request, mpi.tangleTime) { return false // Drop it from the mempool } @@ -551,53 +545,37 @@ func (mpi *mempoolImpl) refsToPropose() []*isc.RequestRef { }) } - expectedAccountNonces := map[string]uint64{} // string is isc.AgentID.String() - requestsNonces := map[string][]reqRefNonce{} // string is isc.AgentID.String() - - mpi.offLedgerPool.Filter(func(request isc.OffLedgerRequest, ts time.Time) bool { - ref := isc.RequestRefFromRequest(request) - reqRefs = append(reqRefs, ref) - - // collect the nonces for each account - senderKey := request.SenderAccount().String() - _, ok := expectedAccountNonces[senderKey] - if !ok { - // get the current state nonce so we can detect gaps with it - expectedAccountNonces[senderKey] = mpi.nonce(request.SenderAccount()) + mpi.offLedgerPool.Iterate(func(account string, entries []*OrderedPoolEntry[isc.OffLedgerRequest]) { + agentID, err := isc.AgentIDFromString(account) + if err != nil { + panic(fmt.Errorf("invalid agentID string: %s", err.Error())) } - requestsNonces[senderKey] = append(requestsNonces[senderKey], reqRefNonce{ref: ref, nonce: request.Nonce()}) - - return true // Keep them for now - }) - - // remove any gaps in the nonces of each account - { - doNotPropose := []*isc.RequestRef{} - for account, refNonces := range requestsNonces { - // sort by nonce - slices.SortFunc(refNonces, func(a, b reqRefNonce) bool { - return a.nonce < b.nonce - }) - // check for gaps with the state nonce - if expectedAccountNonces[account] != refNonces[0].nonce { - // if the first one doesn't match the nonce required from the state, don't propose any of the following - for _, ref := range refNonces { - doNotPropose = append(doNotPropose, ref.ref) - } + accountNonce := mpi.nonce(agentID) + for _, e := range entries { + reqNonce := e.req.Nonce() + if reqNonce < accountNonce { + // nonce too old, delete + mpi.log.Debugf("refsToPropose, account: %s, removing request (%s) with old nonce (%d) from the pool", account, e.req.ID(), e.req.Nonce()) + mpi.offLedgerPool.Remove(e.req) + continue + } + if e.old { + // this request was marked as "old", do not propose it + mpi.log.Debugf("refsToPropose, account: %s, skipping old request: %s", account, e.req.ID().String()) continue } - // check for gaps within the request list - for i := 1; i < len(refNonces); i++ { - if refNonces[i].nonce != refNonces[i-1].nonce+1 { - doNotPropose = append(doNotPropose, refNonces[i].ref) - } + if reqNonce == accountNonce { + // expected nonce, add it to the list to propose + mpi.log.Debugf("refsToPropose, account: %s, proposing reqID %s with nonce: %d", account, e.req.ID().String(), e.req.Nonce()) + reqRefs = append(reqRefs, isc.RequestRefFromRequest(e.req)) + accountNonce++ // increment the account nonce to match the next valid request + } + if reqNonce > accountNonce { + mpi.log.Debugf("refsToPropose, account: %s, req %s has a nonce %d which is too high (expected %d), won't be proposed", account, e.req.ID().String(), e.req.Nonce(), accountNonce) + return // no more valid nonces for this account, continue to the next account } } - // remove undesirable requests from the proposal - reqRefs = lo.Filter(reqRefs, func(x *isc.RequestRef, _ int) bool { - return !slices.Contains(doNotPropose, x) - }) - } + }) return reqRefs } diff --git a/packages/chain/mempool/mempool_test.go b/packages/chain/mempool/mempool_test.go index c39f2bb82c..9d2fe4586e 100644 --- a/packages/chain/mempool/mempool_test.go +++ b/packages/chain/mempool/mempool_test.go @@ -603,6 +603,74 @@ func TestMempoolsNonceGaps(t *testing.T) { // nonce 10 was never proposed } +func TestMempoolOverrideNonce(t *testing.T) { + // 1 node setup + // send nonce 0 + // send another request with the same nonce 0 + // assert the last request is proposed + te := newEnv(t, 1, 0, true) + defer te.close() + + tangleTime := time.Now() + for _, node := range te.mempools { + node.ServerNodesUpdated(te.peerPubKeys, te.peerPubKeys) + node.TangleTimeUpdated(tangleTime) + } + awaitTrackHeadChannels := make([]<-chan bool, len(te.mempools)) + // deposit some funds so off-ledger requests can go through + t.Log("TrackNewChainHead") + for i, node := range te.mempools { + awaitTrackHeadChannels[i] = node.TrackNewChainHead(te.stateForAO(i, te.originAO), nil, te.originAO, []state.Block{}, []state.Block{}) + } + for i := range te.mempools { + <-awaitTrackHeadChannels[i] + } + + output := transaction.BasicOutputFromPostData( + te.governor.Address(), + isc.HnameNil, + isc.RequestParameters{ + TargetAddress: te.chainID.AsAddress(), + Assets: isc.NewAssetsBaseTokens(10 * isc.Million), + }, + ) + onLedgerReq, err := isc.OnLedgerFromUTXO(output, tpkg.RandOutputID(uint16(0))) + require.NoError(t, err) + for _, node := range te.mempools { + node.ReceiveOnLedgerRequest(onLedgerReq) + } + currentAO := blockFn(te, []isc.Request{onLedgerReq}, te.originAO, tangleTime) + + initialReq := isc.NewOffLedgerRequest( + isc.RandomChainID(), + isc.Hn("foo"), + isc.Hn("bar"), + dict.New(), + 0, + gas.LimitsDefault.MaxGasPerRequest, + ).Sign(te.governor) + + require.NoError(t, te.mempools[0].ReceiveOffLedgerRequest(initialReq)) + time.Sleep(200 * time.Millisecond) // give some time for the requests to reach the pool + + overwritingReq := isc.NewOffLedgerRequest( + isc.RandomChainID(), + isc.Hn("baz"), + isc.Hn("bar"), + dict.New(), + 0, + gas.LimitsDefault.MaxGasPerRequest, + ).Sign(te.governor) + + require.NoError(t, te.mempools[0].ReceiveOffLedgerRequest(overwritingReq)) + time.Sleep(200 * time.Millisecond) // give some time for the requests to reach the pool + reqRefs := <-te.mempools[0].ConsensusProposalAsync(te.ctx, currentAO) + proposedReqs := <-te.mempools[0].ConsensusRequestsAsync(te.ctx, reqRefs) + require.Len(t, proposedReqs, 1) + require.Equal(t, overwritingReq, proposedReqs[0]) + require.NotEqual(t, initialReq, proposedReqs[0]) +} + //////////////////////////////////////////////////////////////////////////////// // testEnv diff --git a/packages/chain/mempool/typed_pool_by_nonce.go b/packages/chain/mempool/typed_pool_by_nonce.go new file mode 100644 index 0000000000..a115f9426e --- /dev/null +++ b/packages/chain/mempool/typed_pool_by_nonce.go @@ -0,0 +1,169 @@ +// Copyright 2020 IOTA Stiftung +// SPDX-License-Identifier: Apache-2.0 + +package mempool + +import ( + "fmt" + "time" + + "golang.org/x/exp/slices" + + "github.com/iotaledger/hive.go/ds/shrinkingmap" + "github.com/iotaledger/hive.go/logger" + "github.com/iotaledger/wasp/packages/isc" +) + +// keeps a map of requests ordered by nonce for each account +type TypedPoolByNonce[V isc.OffLedgerRequest] struct { + waitReq WaitReq + refLUT *shrinkingmap.ShrinkingMap[isc.RequestRefKey, *OrderedPoolEntry[V]] + // reqsByAcountOrdered keeps an ordered map of reqsByAcountOrdered for each account by nonce + reqsByAcountOrdered *shrinkingmap.ShrinkingMap[string, []*OrderedPoolEntry[V]] // string is isc.AgentID.String() + sizeMetric func(int) + timeMetric func(time.Duration) + log *logger.Logger +} + +var _ RequestPool[isc.OffLedgerRequest] = &TypedPoolByNonce[isc.OffLedgerRequest]{} + +func NewTypedPoolByNonce[V isc.OffLedgerRequest](waitReq WaitReq, sizeMetric func(int), timeMetric func(time.Duration), log *logger.Logger) *TypedPoolByNonce[V] { + return &TypedPoolByNonce[V]{ + waitReq: waitReq, + reqsByAcountOrdered: shrinkingmap.New[string, []*OrderedPoolEntry[V]](), + refLUT: shrinkingmap.New[isc.RequestRefKey, *OrderedPoolEntry[V]](), + sizeMetric: sizeMetric, + timeMetric: timeMetric, + log: log, + } +} + +type OrderedPoolEntry[V isc.OffLedgerRequest] struct { + req V + old bool + ts time.Time +} + +func (p *TypedPoolByNonce[V]) Has(reqRef *isc.RequestRef) bool { + return p.refLUT.Has(reqRef.AsKey()) +} + +func (p *TypedPoolByNonce[V]) Get(reqRef *isc.RequestRef) V { + entry, exists := p.refLUT.Get(reqRef.AsKey()) + if !exists { + return *new(V) + } + return entry.req +} + +func (p *TypedPoolByNonce[V]) Add(request V) { + ref := isc.RequestRefFromRequest(request) + entry := &OrderedPoolEntry[V]{req: request, ts: time.Now()} + account := request.SenderAccount().String() + + if !p.refLUT.Set(ref.AsKey(), entry) { + p.log.Debugf("NOT ADDED, already exists. reqID: %v as key=%v, senderAccount: ", request.ID(), ref, account) + return // not added already exists + } + + defer func() { + p.log.Debugf("ADD %v as key=%v, senderAccount: ", request.ID(), ref, account) + p.sizeMetric(p.refLUT.Size()) + p.waitReq.MarkAvailable(request) + }() + + reqsForAcount, exists := p.reqsByAcountOrdered.Get(account) + if !exists { + // no other requests for this account + p.reqsByAcountOrdered.Set(account, []*OrderedPoolEntry[V]{entry}) + return + } + + // add to the account requests, keep the slice ordered + + // find the index where the new entry should be added + index, exists := slices.BinarySearchFunc(reqsForAcount, entry, + func(a, b *OrderedPoolEntry[V]) int { + aNonce := a.req.Nonce() + bNonce := b.req.Nonce() + if aNonce == bNonce { + return 0 + } + if aNonce > bNonce { + return 1 + } + return -1 + }, + ) + if exists { + // same nonce, mark the existing request with overlapping nonce as "old", place the new one + // NOTE: do not delete the request here, as it might already be part of an on-going consensus round + reqsForAcount[index].old = true + } + + reqsForAcount = append(reqsForAcount, entry) // add to the end of the list (thus extending the array) + + // make room if target position is not at the end + if index != len(reqsForAcount)+1 { + copy(reqsForAcount[index+1:], reqsForAcount[index:]) + reqsForAcount[index] = entry + } + p.reqsByAcountOrdered.Set(account, reqsForAcount) +} + +func (p *TypedPoolByNonce[V]) Remove(request V) { + refKey := isc.RequestRefFromRequest(request).AsKey() + entry, exists := p.refLUT.Get(refKey) + if !exists { + return // does not exist + } + defer func() { + p.sizeMetric(p.refLUT.Size()) + p.timeMetric(time.Since(entry.ts)) + }() + if p.refLUT.Delete(refKey) { + p.log.Debugf("DEL %v as key=%v", request.ID(), refKey) + } + account := entry.req.SenderAccount().String() + reqsByAccount, exists := p.reqsByAcountOrdered.Get(account) + if !exists { + p.log.Error("inconsistency trying to DEL %v as key=%v, no request list for account %s", request.ID(), refKey, account) + return + } + // find the request in the accounts map + indexToDel := slices.IndexFunc(reqsByAccount, func(e *OrderedPoolEntry[V]) bool { + return refKey == isc.RequestRefFromRequest(e.req).AsKey() + }) + if indexToDel == -1 { + p.log.Error("inconsistency trying to DEL %v as key=%v, request not found in list for account %s", request.ID(), refKey, account) + return + } + if len(reqsByAccount) == 1 { // just remove the entire array for the account + p.reqsByAcountOrdered.Delete(account) + return + } + reqsByAccount[indexToDel] = nil // remove the pointer reference to allow GC of the entry object + reqsByAccount = slices.Delete(reqsByAccount, indexToDel, indexToDel+1) + p.reqsByAcountOrdered.Set(account, reqsByAccount) +} + +func (p *TypedPoolByNonce[V]) Iterate(f func(account string, requests []*OrderedPoolEntry[V])) { + p.reqsByAcountOrdered.ForEach(func(acc string, entries []*OrderedPoolEntry[V]) bool { + f(acc, slices.Clone(entries)) + return true + }) +} + +func (p *TypedPoolByNonce[V]) Filter(predicate func(request V, ts time.Time) bool) { + p.refLUT.ForEach(func(refKey isc.RequestRefKey, entry *OrderedPoolEntry[V]) bool { + if !predicate(entry.req, entry.ts) { + p.Remove(entry.req) + } + return true + }) + p.sizeMetric(p.refLUT.Size()) +} + +func (p *TypedPoolByNonce[V]) StatusString() string { + return fmt.Sprintf("{|req|=%d}", p.refLUT.Size()) +} diff --git a/packages/chain/mempool/typed_pool_by_nonce_test.go b/packages/chain/mempool/typed_pool_by_nonce_test.go new file mode 100644 index 0000000000..85ef62afd8 --- /dev/null +++ b/packages/chain/mempool/typed_pool_by_nonce_test.go @@ -0,0 +1,51 @@ +package mempool + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/iotaledger/wasp/packages/isc" + "github.com/iotaledger/wasp/packages/testutil" + "github.com/iotaledger/wasp/packages/testutil/testkey" + "github.com/iotaledger/wasp/packages/testutil/testlogger" +) + +func TestSomething(t *testing.T) { + waitReq := NewWaitReq(waitRequestCleanupEvery) + pool := NewTypedPoolByNonce[isc.OffLedgerRequest](waitReq, func(int) {}, func(time.Duration) {}, testlogger.NewSilentLogger("", true)) + + // generate a bunch of requests for the same account + kp, addr := testkey.GenKeyAddr() + agentID := isc.NewAgentID(addr) + + req0 := testutil.DummyOffledgerRequestForAccount(isc.RandomChainID(), 0, kp) + req1 := testutil.DummyOffledgerRequestForAccount(isc.RandomChainID(), 1, kp) + req2 := testutil.DummyOffledgerRequestForAccount(isc.RandomChainID(), 2, kp) + req2new := testutil.DummyOffledgerRequestForAccount(isc.RandomChainID(), 2, kp) + pool.Add(req0) + pool.Add(req1) + pool.Add(req1) // try to add the same request many times + pool.Add(req2) + pool.Add(req1) + require.EqualValues(t, 3, pool.refLUT.Size()) + require.EqualValues(t, 1, pool.reqsByAcountOrdered.Size()) + reqsInPoolForAccount, _ := pool.reqsByAcountOrdered.Get(agentID.String()) + require.Len(t, reqsInPoolForAccount, 3) + pool.Add(req2new) + pool.Add(req2new) + require.EqualValues(t, 4, pool.refLUT.Size()) + require.EqualValues(t, 1, pool.reqsByAcountOrdered.Size()) + reqsInPoolForAccount, _ = pool.reqsByAcountOrdered.Get(agentID.String()) + require.Len(t, reqsInPoolForAccount, 4) + + // try to remove everything during iteration + pool.Iterate(func(account string, entries []*OrderedPoolEntry[isc.OffLedgerRequest]) { + for _, e := range entries { + pool.Remove(e.req) + } + }) + require.EqualValues(t, 0, pool.refLUT.Size()) + require.EqualValues(t, 0, pool.reqsByAcountOrdered.Size()) +} diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go index 25c214875f..fabe0398d7 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_cache.go @@ -89,7 +89,7 @@ func (bcT *blockCache) GetBlock(commitment *state.L1Commitment) state.Block { if bcT.wal.Contains(commitment.BlockHash()) { block, err := bcT.wal.Read(commitment.BlockHash()) if err != nil { - bcT.log.Errorf("Error reading block %s from WAL: %w", commitment, err) + bcT.log.Errorf("Error reading block index %v %s from WAL: %w", block.StateIndex(), commitment, err) return nil } bcT.addBlockToCache(block) diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go index 2dca39e4a4..7944764814 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal.go @@ -1,8 +1,9 @@ package sm_gpa_utils import ( - "bufio" + "encoding/hex" "fmt" + "io" "os" "path/filepath" "sort" @@ -15,6 +16,7 @@ import ( "github.com/iotaledger/wasp/packages/isc" "github.com/iotaledger/wasp/packages/metrics" "github.com/iotaledger/wasp/packages/state" + "github.com/iotaledger/wasp/packages/util/rwutil" ) type blockWAL struct { @@ -24,7 +26,10 @@ type blockWAL struct { metrics *metrics.ChainBlockWALMetrics } -const constBlockWALFileSuffix = ".blk" +const ( + constBlockWALFileSuffix = ".blk" + constBlockWALTmpFileSuffix = ".tmp" +) func NewBlockWAL(log *logger.Logger, baseDir string, chainID isc.ChainID, metrics *metrics.ChainBlockWALMetrics) (BlockWAL, error) { dir := filepath.Join(baseDir, chainID.String()) @@ -42,40 +47,89 @@ func NewBlockWAL(log *logger.Logger, baseDir string, chainID isc.ChainID, metric } // Overwrites, if block is already in WAL +// Block format (version 1): +// - Version (4 bytes, unsigned int); value 1 +// - State index (4 bytes, unsigned int) +// - Block bytes +// +// Block format (legacy = version 0): +// - Block bytes func (bwT *blockWAL) Write(block state.Block) error { blockIndex := block.StateIndex() commitment := block.L1Commitment() - fileName := blockWALFileName(commitment.BlockHash()) - filePath := filepath.Join(bwT.dir, fileName) - f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) - if err != nil { - bwT.metrics.IncFailedWrites() - return fmt.Errorf("opening file %s for writing block index %v failed: %w", fileName, blockIndex, err) + subfolderName := blockWALSubFolderName(commitment.BlockHash()) + folderPath := filepath.Join(bwT.dir, subfolderName) + if err := ioutils.CreateDirectory(folderPath, 0o777); err != nil { + return fmt.Errorf("failed to create folder %s for writing block: %w", folderPath, err) } - defer f.Close() - blockBytes := block.Bytes() - n, err := f.Write(blockBytes) + tmpFileName := blockWALTmpFileName(commitment.BlockHash()) + tmpFilePath := filepath.Join(folderPath, tmpFileName) + err := func() error { // Function is used to make defered close occur when it is needed even if write is successful + f, err := os.OpenFile(tmpFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) + if err != nil { + bwT.metrics.IncFailedWrites() + return fmt.Errorf("failed to create temporary file %s for writing block: %w", tmpFilePath, err) + } + defer f.Close() + ww := rwutil.NewWriter(f) + ww.WriteUint32(1) // Version; 4 bytes (instead of just 1) to lower number of possible collisions with legacy WAL format + ww.WriteUint32(blockIndex) + if ww.Err != nil { + bwT.metrics.IncFailedWrites() + return fmt.Errorf("failed to write block index into temporary file %s: %w", tmpFilePath, ww.Err) + } + err = block.Write(f) + if err != nil { + bwT.metrics.IncFailedWrites() + return fmt.Errorf("writing block to temporary file %s failed: %w", tmpFilePath, err) + } + return nil + }() if err != nil { - bwT.metrics.IncFailedWrites() - return fmt.Errorf("writing block index %v data to file %s failed: %w", blockIndex, fileName, err) + return err } - if len(blockBytes) != n { - bwT.metrics.IncFailedWrites() - return fmt.Errorf("only %v of total %v bytes of block index %v were written to file %s", n, len(blockBytes), blockIndex, fileName) + finalFileName := blockWALFileName(commitment.BlockHash()) + finalFilePath := filepath.Join(folderPath, finalFileName) + err = os.Rename(tmpFilePath, finalFilePath) + if err != nil { + return fmt.Errorf("failed to move temporary WAL file %s to permanent location %s: %v", + tmpFilePath, finalFilePath, err) } + bwT.metrics.BlockWritten(block.StateIndex()) - bwT.LogDebugf("Block index %v %s written to wal; file name - %s", blockIndex, commitment, fileName) + bwT.LogDebugf("Block index %v %s written to wal; file name - %s", blockIndex, commitment, finalFilePath) return nil } +func (bwT *blockWAL) blockFilepath(blockHash state.BlockHash) (string, bool) { + subfolderName := blockWALSubFolderName(blockHash) + fileName := blockWALFileName(blockHash) + + pathWithSubFolder := filepath.Join(bwT.dir, subfolderName, fileName) + _, err := os.Stat(pathWithSubFolder) + if err == nil { + return pathWithSubFolder, true + } + + // Checked for backward compatibility and for ease of adding some blocks from other sources + pathNoSubFolder := filepath.Join(bwT.dir, fileName) + _, err = os.Stat(pathNoSubFolder) + if err == nil { + return pathNoSubFolder, true + } + return "", false +} + func (bwT *blockWAL) Contains(blockHash state.BlockHash) bool { - _, err := os.Stat(filepath.Join(bwT.dir, blockWALFileName(blockHash))) - return err == nil + _, exists := bwT.blockFilepath(blockHash) + return exists } func (bwT *blockWAL) Read(blockHash state.BlockHash) (state.Block, error) { - fileName := blockWALFileName(blockHash) - filePath := filepath.Join(bwT.dir, fileName) + filePath, exists := bwT.blockFilepath(blockHash) + if !exists { + return nil, fmt.Errorf("block hash %s is not present in WAL", blockHash) + } block, err := blockFromFilePath(filePath) if err != nil { bwT.metrics.IncFailedReads() @@ -88,26 +142,17 @@ func (bwT *blockWAL) Read(blockHash state.BlockHash) (state.Block, error) { // The blocks are provided ordered by the state index, so that they can be applied to the store. // This function reads blocks twice, but tries to minimize the amount of memory required to load the WAL. func (bwT *blockWAL) ReadAllByStateIndex(cb func(stateIndex uint32, block state.Block) bool) error { - dirEntries, err := os.ReadDir(bwT.dir) - if err != nil { - return err - } blocksByStateIndex := map[uint32][]string{} - for _, dirEntry := range dirEntries { - if !dirEntry.Type().IsRegular() { - continue + checkFile := func(filePath string) { + if !strings.HasSuffix(filePath, constBlockWALFileSuffix) { + return } - if !strings.HasSuffix(dirEntry.Name(), constBlockWALFileSuffix) { - continue - } - filePath := filepath.Join(bwT.dir, dirEntry.Name()) - fileBlock, fileErr := blockFromFilePath(filePath) - if fileErr != nil { + stateIndex, err := blockIndexFromFilePath(filePath) + if err != nil { bwT.metrics.IncFailedReads() bwT.LogWarn("Unable to read %v: %v", filePath, err) - continue + return } - stateIndex := fileBlock.StateIndex() stateIndexPaths, found := blocksByStateIndex[stateIndex] if found { stateIndexPaths = append(stateIndexPaths, filePath) @@ -116,6 +161,28 @@ func (bwT *blockWAL) ReadAllByStateIndex(cb func(stateIndex uint32, block state. } blocksByStateIndex[stateIndex] = stateIndexPaths } + + var checkDir func(dirPath string, dirEntries []os.DirEntry) + checkDir = func(dirPath string, dirEntries []os.DirEntry) { + for _, dirEntry := range dirEntries { + entryPath := filepath.Join(dirPath, dirEntry.Name()) + if dirEntry.IsDir() { + subDirEntries, err := os.ReadDir(entryPath) + if err == nil { + checkDir(entryPath, subDirEntries) + } + } else { + checkFile(entryPath) + } + } + } + + dirEntries, err := os.ReadDir(bwT.dir) + if err != nil { + return err + } + checkDir(bwT.dir, dirEntries) + allStateIndexes := lo.Keys(blocksByStateIndex) sort.Slice(allStateIndexes, func(i, j int) bool { return allStateIndexes[i] < allStateIndexes[j] }) for _, stateIndex := range allStateIndexes { @@ -135,31 +202,100 @@ func (bwT *blockWAL) ReadAllByStateIndex(cb func(stateIndex uint32, block state. return nil } -func blockFromFilePath(filePath string) (state.Block, error) { +func blockInfoFromFilePath[I any](filePath string, getInfoFun func(uint32, io.Reader) (I, error)) (I, error) { f, err := os.OpenFile(filePath, os.O_RDONLY, 0o666) + var info I if err != nil { - return nil, fmt.Errorf("opening file %s for reading failed: %w", filePath, err) + return info, fmt.Errorf("opening file %s for reading failed: %w", filePath, err) } defer f.Close() - stat, err := f.Stat() - if err != nil { - return nil, fmt.Errorf("reading file %s information failed: %w", filePath, err) + rr := rwutil.NewReader(f) + version := rr.ReadUint32() + if rr.Err != nil { + return info, fmt.Errorf("failed reading file version: %w", rr.Err) + } + var errV error + if version == 1 { + info, errV = getInfoFun(version, f) + if errV == nil { + return info, nil + } + // error reading as version 1, maybe it's legacy version? } - blockBytes := make([]byte, stat.Size()) - n, err := bufio.NewReader(f).Read(blockBytes) + // backwards compatibility - reading legacy version + // NOTE: reopening file, because version bytes (or possibly more) has already been read + f, err = os.OpenFile(filePath, os.O_RDONLY, 0o666) if err != nil { - return nil, fmt.Errorf("reading file %s failed: %w", filePath, err) + return info, fmt.Errorf("reopening file %s for reading failed: %w", filePath, err) } - if int64(n) != stat.Size() { - return nil, fmt.Errorf("only %v of total %v bytes of file %s were read", n, stat.Size(), filePath) + defer f.Close() + info, err = getInfoFun(0, f) + if errV == nil { + return info, err } - block, err := state.BlockFromBytes(blockBytes) - if err != nil { - return nil, fmt.Errorf("error parsing block from bytes read from file %s: %w", filePath, err) + return info, fmt.Errorf("version %v error: %w, legacy version error: %w", version, errV, err) +} + +func blockIndexFromFilePath(filePath string) (uint32, error) { + return blockInfoFromFilePath(filePath, blockIndexFromReader) +} + +func blockFromFilePath(filePath string) (state.Block, error) { + return blockInfoFromFilePath(filePath, blockFromReader) +} + +func blockIndexFromReader(version uint32, r io.Reader) (uint32, error) { + switch version { + case 1: + rr := rwutil.NewReader(r) + index := rr.ReadUint32() + return index, rr.Err + case 0: + block := state.NewBlock() + err := block.Read(r) + if err != nil { + return 0, err + } + return block.StateIndex(), nil + default: + return 0, fmt.Errorf("unknown block version %v", version) + } +} + +func blockFromReader(version uint32, r io.Reader) (state.Block, error) { + switch version { + case 1: + blockIndex, err := blockIndexFromReader(version, r) + if err != nil { + return nil, fmt.Errorf("failed to read block index in header: %w", err) + } + block := state.NewBlock() + err = block.Read(r) + if err != nil { + return nil, fmt.Errorf("failed to read block: %w", err) + } + if blockIndex != block.StateIndex() { + return nil, fmt.Errorf("block index in header %v does not match block index in block %v", + blockIndex, block.StateIndex()) + } + return block, nil + case 0: + block := state.NewBlock() + err := block.Read(r) + return block, err + default: + return nil, fmt.Errorf("unknown block version %v", version) } - return block, nil +} + +func blockWALSubFolderName(blockHash state.BlockHash) string { + return hex.EncodeToString(blockHash[:1]) } func blockWALFileName(blockHash state.BlockHash) string { return blockHash.String() + constBlockWALFileSuffix } + +func blockWALTmpFileName(blockHash state.BlockHash) string { + return blockWALFileName(blockHash) + constBlockWALTmpFileSuffix +} diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go index be7c42cc8c..659a8e0174 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_rapid_test.go @@ -3,7 +3,6 @@ package sm_gpa_utils import ( "crypto/rand" "os" - "path/filepath" "testing" "github.com/samber/lo" @@ -44,7 +43,7 @@ func newBlockWALTestSM(t *rapid.T) *blockWALTestSM { return bwtsmT } -func (bwtsmT *blockWALTestSM) Cleanup() { +func (bwtsmT *blockWALTestSM) cleanup() { bwtsmT.log.Sync() os.RemoveAll(constTestFolder) } @@ -107,8 +106,8 @@ func (bwtsmT *blockWALTestSM) MoveBlock(t *rapid.T) { if blockHashOrig.Equals(blockHashToDamage) { t.Skip() } - fileOrigPath := bwtsmT.pathFromHash(blockHashOrig) - fileToDamagePath := bwtsmT.pathFromHash(blockHashToDamage) + fileOrigPath := walPathFromHash(bwtsmT.factory.GetChainID(), blockHashOrig) + fileToDamagePath := walPathFromHash(bwtsmT.factory.GetChainID(), blockHashToDamage) data, err := os.ReadFile(fileOrigPath) require.NoError(t, err) err = os.WriteFile(fileToDamagePath, data, 0o644) @@ -124,7 +123,7 @@ func (bwtsmT *blockWALTestSM) DamageBlock(t *rapid.T) { t.Skip() } blockHash := rapid.SampledFrom(blockHashes).Example() - filePath := bwtsmT.pathFromHash(blockHash) + filePath := walPathFromHash(bwtsmT.factory.GetChainID(), blockHash) data := make([]byte, 50) _, err := rand.Read(data) require.NoError(t, err) @@ -188,10 +187,6 @@ func (bwtsmT *blockWALTestSM) getGoodBlockHashes() []state.BlockHash { return result } -func (bwtsmT *blockWALTestSM) pathFromHash(blockHash state.BlockHash) string { - return filepath.Join(constTestFolder, bwtsmT.factory.GetChainID().String(), blockWALFileName(blockHash)) -} - func (bwtsmT *blockWALTestSM) invariantAllWrittenBlocksExist(t *rapid.T) { for blockHash := range bwtsmT.blocks { require.True(t, bwtsmT.bw.Contains(blockHash)) @@ -202,6 +197,7 @@ func TestBlockWALPropBased(t *testing.T) { rapid.Check(t, func(t *rapid.T) { sm := newBlockWALTestSM(t) t.Repeat(rapid.StateMachineActions(sm)) + sm.cleanup() }) } diff --git a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go index 81ea15a778..28be65f7ec 100644 --- a/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go +++ b/packages/chain/statemanager/sm_gpa/sm_gpa_utils/block_wal_test.go @@ -49,6 +49,52 @@ func TestBlockWALBasic(t *testing.T) { require.Error(t, err) } +// Check if block prior to version 1 is read (that has no version data) +func TestBlockWALLegacy(t *testing.T) { + log := testlogger.NewLogger(t) + defer log.Sync() + defer cleanupAfterTest(t) + + factory := NewBlockFactory(t) + blocks := factory.GetBlocks(4, 1) + wal, err := NewBlockWAL(log, constTestFolder, factory.GetChainID(), mockBlockWALMetrics()) + require.NoError(t, err) + writeBlocksLegacy(t, factory.GetChainID(), blocks) + for i := range blocks { + block, err := wal.Read(blocks[i].Hash()) + require.NoError(t, err) + CheckBlocksEqual(t, blocks[i], block) + } +} + +// Check if existing block in WAL is found even if it is not in a subfolder +func TestBlockWALNoSubfolder(t *testing.T) { + log := testlogger.NewLogger(t) + defer log.Sync() + defer cleanupAfterTest(t) + + factory := NewBlockFactory(t) + blocks := factory.GetBlocks(4, 1) + wal, err := NewBlockWAL(log, constTestFolder, factory.GetChainID(), mockBlockWALMetrics()) + require.NoError(t, err) + for i := range blocks { + err = wal.Write(blocks[i]) + require.NoError(t, err) + } + for _, block := range blocks { + pathWithSubfolder := walPathFromHash(factory.GetChainID(), block.Hash()) + pathNoSubfolder := walPathNoSubfolderFromHash(factory.GetChainID(), block.Hash()) + err = os.Rename(pathWithSubfolder, pathNoSubfolder) + require.NoError(t, err) + } + for _, block := range blocks { + require.True(t, wal.Contains(block.Hash())) + blockRead, err := wal.Read(block.Hash()) + require.NoError(t, err) + CheckBlocksEqual(t, block, blockRead) + } +} + // Check if existing WAL record is overwritten func TestBlockWALOverwrite(t *testing.T) { log := testlogger.NewLogger(t) @@ -63,11 +109,8 @@ func TestBlockWALOverwrite(t *testing.T) { err = wal.Write(blocks[i]) require.NoError(t, err) } - pathFromHashFun := func(blockHash state.BlockHash) string { - return filepath.Join(constTestFolder, factory.GetChainID().String(), blockWALFileName(blockHash)) - } - file0Path := pathFromHashFun(blocks[0].Hash()) - file1Path := pathFromHashFun(blocks[1].Hash()) + file0Path := walPathFromHash(factory.GetChainID(), blocks[0].Hash()) + file1Path := walPathFromHash(factory.GetChainID(), blocks[1].Hash()) err = os.Rename(file1Path, file0Path) require.NoError(t, err) // block[1] is no longer in WAL @@ -116,6 +159,83 @@ func TestBlockWALRestart(t *testing.T) { } } +func testReadAllByStateIndex(t *testing.T, addToWALFun func(isc.ChainID, BlockWAL, []state.Block)) { + log := testlogger.NewLogger(t) + defer log.Sync() + defer cleanupAfterTest(t) + + factory := NewBlockFactory(t) + mainBlocks := 50 + branchBlocks := 20 + branchBlockIndex := mainBlocks - branchBlocks - 1 + blocksMain := factory.GetBlocks(mainBlocks, 1) + blocksBranch := factory.GetBlocksFrom(branchBlocks, 1, blocksMain[branchBlockIndex].L1Commitment(), 2) + wal, err := NewBlockWAL(log, constTestFolder, factory.GetChainID(), mockBlockWALMetrics()) + require.NoError(t, err) + addToWALFun(factory.GetChainID(), wal, blocksMain) + addToWALFun(factory.GetChainID(), wal, blocksBranch) + + var blocksRead []state.Block + err = wal.ReadAllByStateIndex(func(stateIndex uint32, block state.Block) bool { + require.Equal(t, stateIndex, block.StateIndex()) + blocksRead = append(blocksRead, block) + return true + }) + require.NoError(t, err) + + for i := 0; i <= branchBlockIndex; i++ { + require.Equal(t, uint32(i+1), blocksRead[i].StateIndex()) + CheckBlocksEqual(t, blocksMain[i], blocksRead[i]) + } + for i := branchBlockIndex + 1; i < mainBlocks; i++ { + blocksReadIndex := i*2 - branchBlockIndex - 1 + block1 := blocksRead[blocksReadIndex] + block2 := blocksRead[blocksReadIndex+1] + require.Equal(t, uint32(i+1), block1.StateIndex()) + require.Equal(t, uint32(i+1), block2.StateIndex()) + if !blocksMain[i].L1Commitment().Equals(block1.L1Commitment()) { + block1, block2 = block2, block1 + } + CheckBlocksEqual(t, blocksMain[i], block1) + CheckBlocksEqual(t, blocksBranch[i-branchBlockIndex-1], block2) + } +} + +func TestReadAllByStateIndexV1(t *testing.T) { + testReadAllByStateIndex(t, func(chainID isc.ChainID, wal BlockWAL, blocks []state.Block) { + for _, block := range blocks { + err := wal.Write(block) + require.NoError(t, err) + } + }) +} + +func TestReadAllByStateIndexLegacy(t *testing.T) { + testReadAllByStateIndex(t, func(chainID isc.ChainID, wal BlockWAL, blocks []state.Block) { + writeBlocksLegacy(t, chainID, blocks) + }) +} + +func walPathFromHash(chainID isc.ChainID, blockHash state.BlockHash) string { + return filepath.Join(constTestFolder, chainID.String(), blockWALSubFolderName(blockHash), blockWALFileName(blockHash)) +} + +func walPathNoSubfolderFromHash(chainID isc.ChainID, blockHash state.BlockHash) string { + return filepath.Join(constTestFolder, chainID.String(), blockWALFileName(blockHash)) +} + +func writeBlocksLegacy(t *testing.T, chainID isc.ChainID, blocks []state.Block) { + for _, block := range blocks { + filePath := walPathNoSubfolderFromHash(chainID, block.Hash()) + f, err := os.OpenFile(filePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o666) + require.NoError(t, err) + err = block.Write(f) + require.NoError(t, err) + err = f.Close() + require.NoError(t, err) + } +} + func cleanupAfterTest(t *testing.T) { err := os.RemoveAll(constTestFolder) require.NoError(t, err) diff --git a/packages/chain/statemanager/state_manager.go b/packages/chain/statemanager/state_manager.go index 61d7ab6919..841a3f9858 100644 --- a/packages/chain/statemanager/state_manager.go +++ b/packages/chain/statemanager/state_manager.go @@ -378,15 +378,16 @@ func (smT *stateManager) handleNodePublicKeys(req *reqChainNodesUpdated) { func (smT *stateManager) handlePreliminaryBlock(msg *reqPreliminaryBlock) { if !smT.wal.Contains(msg.block.Hash()) { if err := smT.wal.Write(msg.block); err != nil { - smT.log.Warnf("Preliminary block %v cannot be saved to the WAL: %v", msg.block.L1Commitment(), err) + smT.log.Warnf("Preliminary block index %v %s cannot be saved to the WAL: %v", + msg.block.StateIndex(), msg.block.L1Commitment(), err) msg.Respond(err) return } - smT.log.Warnf("Preliminary block %v saved to the WAL.", msg.block.L1Commitment()) + smT.log.Warnf("Preliminary block index %v %s saved to the WAL.", msg.block.StateIndex(), msg.block.L1Commitment()) msg.Respond(nil) return } - smT.log.Warnf("Preliminary block %v already exist in the WAL.", msg.block.L1Commitment()) + smT.log.Warnf("Preliminary block index %v %s already exist in the WAL.", msg.block.StateIndex(), msg.block.L1Commitment()) msg.Respond(nil) } diff --git a/packages/state/block.go b/packages/state/block.go index dd16dd2c5f..69382a61fd 100644 --- a/packages/state/block.go +++ b/packages/state/block.go @@ -41,7 +41,7 @@ func (b *block) Bytes() []byte { func (b *block) essenceBytes() []byte { ww := rwutil.NewBytesWriter() - ww.WriteFromFunc(b.writeEssence) + b.writeEssence(ww) return ww.Bytes() } @@ -85,20 +85,18 @@ func (b *block) TrieRoot() trie.Hash { func (b *block) Read(r io.Reader) error { rr := rwutil.NewReader(r) rr.ReadN(b.trieRoot[:]) - rr.ReadFromFunc(b.readEssence) + b.readEssence(rr) return rr.Err } func (b *block) Write(w io.Writer) error { ww := rwutil.NewWriter(w) ww.WriteN(b.trieRoot[:]) - ww.WriteFromFunc(b.writeEssence) + b.writeEssence(ww) return ww.Err } -func (b *block) readEssence(r io.Reader) (int, error) { - rr := rwutil.NewReader(r) - counter := rwutil.NewReadCounter(rr) +func (b *block) readEssence(rr *rwutil.Reader) { b.mutations = buffered.NewMutations() rr.Read(b.mutations) hasPrevL1Commitment := rr.ReadBool() @@ -106,17 +104,14 @@ func (b *block) readEssence(r io.Reader) (int, error) { b.previousL1Commitment = new(L1Commitment) rr.Read(b.previousL1Commitment) } - return counter.Count(), rr.Err } -func (b *block) writeEssence(w io.Writer) (int, error) { - ww := rwutil.NewWriter(w) +func (b *block) writeEssence(ww *rwutil.Writer) { ww.Write(b.mutations) ww.WriteBool(b.previousL1Commitment != nil) if b.previousL1Commitment != nil { ww.Write(b.previousL1Commitment) } - return len(ww.Bytes()), ww.Err } // test only function diff --git a/packages/testutil/dummyrequest.go b/packages/testutil/dummyrequest.go index 014a4e0560..697e13b83b 100644 --- a/packages/testutil/dummyrequest.go +++ b/packages/testutil/dummyrequest.go @@ -1,6 +1,7 @@ package testutil import ( + "github.com/iotaledger/wasp/packages/cryptolib" "github.com/iotaledger/wasp/packages/isc" "github.com/iotaledger/wasp/packages/kv/dict" "github.com/iotaledger/wasp/packages/testutil/testkey" @@ -15,3 +16,11 @@ func DummyOffledgerRequest(chainID isc.ChainID) isc.OffLedgerRequest { keys, _ := testkey.GenKeyAddr() return req.Sign(keys) } + +func DummyOffledgerRequestForAccount(chainID isc.ChainID, nonce uint64, kp *cryptolib.KeyPair) isc.OffLedgerRequest { + contract := isc.Hn("somecontract") + entrypoint := isc.Hn("someentrypoint") + args := dict.Dict{} + req := isc.NewOffLedgerRequest(chainID, contract, entrypoint, args, nonce, gas.LimitsDefault.MaxGasPerRequest) + return req.Sign(kp) +} diff --git a/packages/vm/core/testcore/sbtests/sbtestsc/testcore_bg.wasm b/packages/vm/core/testcore/sbtests/sbtestsc/testcore_bg.wasm index d7fea954b6..159c25ab05 100644 Binary files a/packages/vm/core/testcore/sbtests/sbtestsc/testcore_bg.wasm and b/packages/vm/core/testcore/sbtests/sbtestsc/testcore_bg.wasm differ diff --git a/packages/vm/vmimpl/runreq.go b/packages/vm/vmimpl/runreq.go index 890aad73f6..09b8d97f91 100644 --- a/packages/vm/vmimpl/runreq.go +++ b/packages/vm/vmimpl/runreq.go @@ -28,7 +28,7 @@ import ( "github.com/iotaledger/wasp/packages/vm/vmexceptions" ) -// runRequest processes a single isc.Request in the batch +// runRequest processes a single isc.Request in the batch, returning an error means the request will be skipped func (vmctx *vmContext) runRequest(req isc.Request, requestIndex uint16, maintenanceMode bool) ( res *vm.RequestResult, unprocessableToRetry []isc.OnLedgerRequest, diff --git a/packages/vm/vmtxbuilder/foundries.go b/packages/vm/vmtxbuilder/foundries.go index 580a6c602e..7ffde82f9e 100644 --- a/packages/vm/vmtxbuilder/foundries.go +++ b/packages/vm/vmtxbuilder/foundries.go @@ -44,9 +44,9 @@ func (txb *AnchorTransactionBuilder) CreateNewFoundry( } f.Amount = parameters.L1().Protocol.RentStructure.MinRent(f) txb.invokedFoundries[f.SerialNumber] = &foundryInvoked{ - serialNumber: f.SerialNumber, - in: nil, - out: f, + serialNumber: f.SerialNumber, + accountingInput: nil, + accountingOutput: f, } return f.SerialNumber, f.Amount } @@ -60,14 +60,14 @@ func (txb *AnchorTransactionBuilder) ModifyNativeTokenSupply(nativeTokenID iotag panic(vm.ErrFoundryDoesNotExist) } // check if the loaded foundry matches the nativeTokenID - if nativeTokenID != f.in.MustNativeTokenID() { + if nativeTokenID != f.accountingInput.MustNativeTokenID() { panic(fmt.Errorf("%v: requested token ID: %s, foundry token id: %s", - vm.ErrCantModifySupplyOfTheToken, nativeTokenID.String(), f.in.MustNativeTokenID().String())) + vm.ErrCantModifySupplyOfTheToken, nativeTokenID.String(), f.accountingInput.MustNativeTokenID().String())) } defer txb.mustCheckTotalNativeTokensExceeded() - simpleTokenScheme := util.MustTokenScheme(f.out.TokenScheme) + simpleTokenScheme := util.MustTokenScheme(f.accountingOutput.TokenScheme) // check the supply bounds var newMinted, newMelted *big.Int @@ -88,7 +88,7 @@ func (txb *AnchorTransactionBuilder) ModifyNativeTokenSupply(nativeTokenID iotag simpleTokenScheme.MeltedTokens = newMelted txb.invokedFoundries[sn] = f - adjustment += int64(f.in.Amount) - int64(f.out.Amount) + adjustment += int64(f.accountingInput.Amount) - int64(f.accountingOutput.Amount) return adjustment } @@ -103,10 +103,10 @@ func (txb *AnchorTransactionBuilder) ensureFoundry(sn uint32) *foundryInvoked { return nil } f := &foundryInvoked{ - serialNumber: foundryOutput.SerialNumber, - outputID: outputID, - in: foundryOutput, - out: cloneFoundryOutput(foundryOutput), + serialNumber: foundryOutput.SerialNumber, + accountingInputID: outputID, + accountingInput: foundryOutput, + accountingOutput: cloneFoundryOutput(foundryOutput), } txb.invokedFoundries[sn] = f return f @@ -118,14 +118,14 @@ func (txb *AnchorTransactionBuilder) DestroyFoundry(sn uint32) uint64 { if f == nil { panic(vm.ErrFoundryDoesNotExist) } - if f.in == nil { + if f.accountingInput == nil { panic(vm.ErrCantDestroyFoundryBeingCreated) } defer txb.mustCheckTotalNativeTokensExceeded() - f.out = nil - return f.in.Amount + f.accountingOutput = nil + return f.accountingInput.Amount } func (txb *AnchorTransactionBuilder) nextFoundrySerialNumber() uint32 { @@ -172,27 +172,27 @@ func (txb *AnchorTransactionBuilder) FoundriesToBeUpdated() ([]uint32, []uint32) func (txb *AnchorTransactionBuilder) FoundryOutputsBySN(serNums []uint32) map[uint32]*iotago.FoundryOutput { ret := make(map[uint32]*iotago.FoundryOutput) for _, sn := range serNums { - ret[sn] = txb.invokedFoundries[sn].out + ret[sn] = txb.invokedFoundries[sn].accountingOutput } return ret } type foundryInvoked struct { - serialNumber uint32 - outputID iotago.OutputID // if in != nil - in *iotago.FoundryOutput // nil if created - out *iotago.FoundryOutput // nil if destroyed + serialNumber uint32 + accountingInputID iotago.OutputID // if in != nil + accountingInput *iotago.FoundryOutput // nil if created + accountingOutput *iotago.FoundryOutput // nil if destroyed } func (f *foundryInvoked) Clone() *foundryInvoked { outputID := iotago.OutputID{} - copy(outputID[:], f.outputID[:]) + copy(outputID[:], f.accountingInputID[:]) return &foundryInvoked{ - serialNumber: f.serialNumber, - outputID: outputID, - in: cloneFoundryOutput(f.in), - out: cloneFoundryOutput(f.out), + serialNumber: f.serialNumber, + accountingInputID: outputID, + accountingInput: cloneFoundryOutput(f.accountingInput), + accountingOutput: cloneFoundryOutput(f.accountingOutput), } } @@ -201,20 +201,20 @@ func (f *foundryInvoked) isNewCreated() bool { } func (f *foundryInvoked) requiresExistingAccountingUTXOAsInput() bool { - if f.in == nil { + if f.accountingInput == nil { return false } - if identicalFoundries(f.in, f.out) { + if identicalFoundries(f.accountingInput, f.accountingOutput) { return false } return true } func (f *foundryInvoked) producesAccountingOutput() bool { - if f.out == nil { + if f.accountingOutput == nil { return false } - if identicalFoundries(f.in, f.out) { + if identicalFoundries(f.accountingInput, f.accountingOutput) { return false } return true diff --git a/packages/vm/vmtxbuilder/nfts.go b/packages/vm/vmtxbuilder/nfts.go index 9eee321696..579fbd25c4 100644 --- a/packages/vm/vmtxbuilder/nfts.go +++ b/packages/vm/vmtxbuilder/nfts.go @@ -10,11 +10,11 @@ import ( ) type nftIncluded struct { - ID iotago.NFTID - outputID iotago.OutputID // only available when the input is already accounted for (NFT was deposited in a previous block) - in *iotago.NFTOutput - out *iotago.NFTOutput // this is not the same as in the `nativeTokenBalance` struct, this can be the accounting output, or the output leaving the chain. // TODO should refactor to follow the same logic so its easier to grok - sentOutside bool + ID iotago.NFTID + accountingInputID iotago.OutputID // only available when the input is already accounted for (NFT was deposited in a previous block) + accountingInput *iotago.NFTOutput + resultingOutput *iotago.NFTOutput // this is not the same as in the `nativeTokenBalance` struct, this can be the accounting output, or the output leaving the chain. // TODO should refactor to follow the same logic so its easier to grok + sentOutside bool } // 3 cases of handling NFTs in txbuilder @@ -28,13 +28,13 @@ func (n *nftIncluded) Clone() *nftIncluded { copy(nftID[:], n.ID[:]) outputID := iotago.OutputID{} - copy(outputID[:], n.outputID[:]) + copy(outputID[:], n.accountingInputID[:]) return &nftIncluded{ - ID: nftID, - outputID: outputID, - in: cloneInternalNFTOutputOrNil(n.in), - out: cloneInternalNFTOutputOrNil(n.out), + ID: nftID, + accountingInputID: outputID, + accountingInput: cloneInternalNFTOutputOrNil(n.accountingInput), + resultingOutput: cloneInternalNFTOutputOrNil(n.resultingOutput), } } @@ -61,7 +61,7 @@ func (txb *AnchorTransactionBuilder) NFTOutputs() []*iotago.NFTOutput { for _, nft := range txb.nftsSorted() { if !nft.sentOutside { // outputs sent outside are already added to txb.postedOutputs - outs = append(outs, nft.out) + outs = append(outs, nft.resultingOutput) } } return outs @@ -71,9 +71,9 @@ func (txb *AnchorTransactionBuilder) NFTOutputsToBeUpdated() (toBeAdded, toBeRem toBeAdded = make([]*iotago.NFTOutput, 0, len(txb.nftsIncluded)) toBeRemoved = make([]*iotago.NFTOutput, 0, len(txb.nftsIncluded)) for _, nft := range txb.nftsSorted() { - if nft.in != nil { + if nft.accountingInput != nil { // to remove if input is not nil (nft exists in accounting), and its sent to outside the chain - toBeRemoved = append(toBeRemoved, nft.out) + toBeRemoved = append(toBeRemoved, nft.resultingOutput) continue } if nft.sentOutside { @@ -81,7 +81,7 @@ func (txb *AnchorTransactionBuilder) NFTOutputsToBeUpdated() (toBeAdded, toBeRem continue } // to add if input is nil (doesn't exist in accounting), and its not sent outside the chain - toBeAdded = append(toBeAdded, nft.out) + toBeAdded = append(toBeAdded, nft.resultingOutput) } return toBeAdded, toBeRemoved } @@ -111,10 +111,10 @@ func (txb *AnchorTransactionBuilder) internalNFTOutputFromRequest(nftOutput *iot out.Amount = parameters.L1().Protocol.RentStructure.MinRent(out) ret := &nftIncluded{ - ID: out.NFTID, - in: nil, - out: out, - sentOutside: false, + ID: out.NFTID, + accountingInput: nil, + resultingOutput: out, + sentOutside: false, } txb.nftsIncluded[out.NFTID] = ret @@ -129,8 +129,8 @@ func (txb *AnchorTransactionBuilder) sendNFT(o *iotago.NFTOutput) int64 { if txb.nftsIncluded[o.NFTID] != nil { // NFT comes in and out in the same block txb.nftsIncluded[o.NFTID].sentOutside = true - sd := txb.nftsIncluded[o.NFTID].out.Amount // reimburse the SD cost - txb.nftsIncluded[o.NFTID].out = o + sd := txb.nftsIncluded[o.NFTID].resultingOutput.Amount // reimburse the SD cost + txb.nftsIncluded[o.NFTID].resultingOutput = o return int64(sd) } if txb.InputsAreFull() { @@ -140,11 +140,11 @@ func (txb *AnchorTransactionBuilder) sendNFT(o *iotago.NFTOutput) int64 { // using NFT already owned by the chain in, outputID := txb.accountsView.NFTOutput(o.NFTID) toInclude := &nftIncluded{ - ID: o.NFTID, - in: in, - outputID: outputID, - out: o, - sentOutside: true, + ID: o.NFTID, + accountingInput: in, + accountingInputID: outputID, + resultingOutput: o, + sentOutside: true, } txb.nftsIncluded[o.NFTID] = toInclude diff --git a/packages/vm/vmtxbuilder/tokens.go b/packages/vm/vmtxbuilder/tokens.go index feb538807f..1473752f93 100644 --- a/packages/vm/vmtxbuilder/tokens.go +++ b/packages/vm/vmtxbuilder/tokens.go @@ -15,10 +15,10 @@ import ( // nativeTokenBalance represents on-chain account of the specific native token type nativeTokenBalance struct { - nativeTokenID iotago.NativeTokenID - accountingoutputID iotago.OutputID // if in != nil, otherwise zeroOutputID - in *iotago.BasicOutput // if nil it means output does not exist, this is new account for the token_id - accountingOutput *iotago.BasicOutput // current balance of the token_id on the chain + nativeTokenID iotago.NativeTokenID + accountingInputID iotago.OutputID // if in != nil, otherwise zeroOutputID + accountingInput *iotago.BasicOutput // if nil it means output does not exist, this is new account for the token_id + accountingOutput *iotago.BasicOutput // current balance of the token_id on the chain } func (n *nativeTokenBalance) Clone() *nativeTokenBalance { @@ -26,13 +26,13 @@ func (n *nativeTokenBalance) Clone() *nativeTokenBalance { copy(nativeTokenID[:], n.nativeTokenID[:]) outputID := iotago.OutputID{} - copy(outputID[:], n.accountingoutputID[:]) + copy(outputID[:], n.accountingInputID[:]) return &nativeTokenBalance{ - nativeTokenID: nativeTokenID, - accountingoutputID: outputID, - in: cloneInternalBasicOutputOrNil(n.in), - accountingOutput: cloneInternalBasicOutputOrNil(n.accountingOutput), + nativeTokenID: nativeTokenID, + accountingInputID: outputID, + accountingInput: cloneInternalBasicOutputOrNil(n.accountingInput), + accountingOutput: cloneInternalBasicOutputOrNil(n.accountingOutput), } } @@ -55,7 +55,7 @@ func (n *nativeTokenBalance) requiresExistingAccountingUTXOAsInput() bool { // value didn't change return false } - return n.in != nil + return n.accountingInput != nil } func (n *nativeTokenBalance) getOutValue() *big.Int { @@ -86,23 +86,23 @@ func (n *nativeTokenBalance) updateMinSD() { func (n *nativeTokenBalance) identicalInOut() bool { switch { - case n.in == n.accountingOutput: + case n.accountingInput == n.accountingOutput: panic("identicalBasicOutputs: internal inconsistency 1") - case n.in == nil || n.accountingOutput == nil: + case n.accountingInput == nil || n.accountingOutput == nil: return false - case !n.in.Ident().Equal(n.accountingOutput.Ident()): + case !n.accountingInput.Ident().Equal(n.accountingOutput.Ident()): return false - case n.in.Amount != n.accountingOutput.Amount: + case n.accountingInput.Amount != n.accountingOutput.Amount: return false - case !n.in.NativeTokens.Equal(n.accountingOutput.NativeTokens): + case !n.accountingInput.NativeTokens.Equal(n.accountingOutput.NativeTokens): return false - case !n.in.Features.Equal(n.accountingOutput.Features): + case !n.accountingInput.Features.Equal(n.accountingOutput.Features): return false - case len(n.in.NativeTokens) != 1: + case len(n.accountingInput.NativeTokens) != 1: panic("identicalBasicOutputs: internal inconsistency 2") case len(n.accountingOutput.NativeTokens) != 1: panic("identicalBasicOutputs: internal inconsistency 3") - case n.in.NativeTokens[0].ID != n.nativeTokenID: + case n.accountingInput.NativeTokens[0].ID != n.nativeTokenID: panic("identicalBasicOutputs: internal inconsistency 4") case n.accountingOutput.NativeTokens[0].ID != n.nativeTokenID: panic("identicalBasicOutputs: internal inconsistency 5") @@ -187,11 +187,11 @@ func (txb *AnchorTransactionBuilder) addNativeTokenBalanceDelta(nativeTokenID io if util.IsZeroBigInt(nt.getOutValue()) { // 0 native tokens on the output side - if nt.in == nil { + if nt.accountingInput == nil { // in this case the internar accounting output that would be created is not needed anymore, reiburse the SD return int64(nt.accountingOutput.Amount) } - return int64(nt.in.Amount) + return int64(nt.accountingInput.Amount) } // update the SD in case the storage deposit has changed from the last time this output was used @@ -228,10 +228,10 @@ func (txb *AnchorTransactionBuilder) ensureNativeTokenBalance(nativeTokenID iota } nativeTokenBalance := &nativeTokenBalance{ - nativeTokenID: nativeTokenID, - accountingoutputID: outputID, - in: basicOutputIn, - accountingOutput: basicOutputOut, + nativeTokenID: nativeTokenID, + accountingInputID: outputID, + accountingInput: basicOutputIn, + accountingOutput: basicOutputOut, } txb.balanceNativeTokens[nativeTokenID] = nativeTokenBalance return nativeTokenBalance diff --git a/packages/vm/vmtxbuilder/totals.go b/packages/vm/vmtxbuilder/totals.go index d7fe896d0a..d23843c38e 100644 --- a/packages/vm/vmtxbuilder/totals.go +++ b/packages/vm/vmtxbuilder/totals.go @@ -45,10 +45,10 @@ func (txb *AnchorTransactionBuilder) sumInputs() *TransactionTotals { if !ok { s = new(big.Int) } - s.Add(s, ntb.in.NativeTokens[0].Amount) + s.Add(s, ntb.accountingInput.NativeTokens[0].Amount) totals.NativeTokenBalances[id] = s // sum up storage deposit in inputs of internal UTXOs - totals.TotalBaseTokensInStorageDeposit += ntb.in.Amount + totals.TotalBaseTokensInStorageDeposit += ntb.accountingInput.Amount } // sum up all explicitly consumed outputs, except anchor output for _, out := range txb.consumed { @@ -65,16 +65,16 @@ func (txb *AnchorTransactionBuilder) sumInputs() *TransactionTotals { } for _, f := range txb.invokedFoundries { if f.requiresExistingAccountingUTXOAsInput() { - totals.TotalBaseTokensInStorageDeposit += f.in.Amount - simpleTokenScheme := util.MustTokenScheme(f.in.TokenScheme) - totals.TokenCirculatingSupplies[f.in.MustNativeTokenID()] = new(big.Int). + totals.TotalBaseTokensInStorageDeposit += f.accountingInput.Amount + simpleTokenScheme := util.MustTokenScheme(f.accountingInput.TokenScheme) + totals.TokenCirculatingSupplies[f.accountingInput.MustNativeTokenID()] = new(big.Int). Sub(simpleTokenScheme.MintedTokens, simpleTokenScheme.MeltedTokens) } } for _, nft := range txb.nftsIncluded { - if !isc.IsEmptyOutputID(nft.outputID) { - totals.TotalBaseTokensInStorageDeposit += nft.in.Amount + if !isc.IsEmptyOutputID(nft.accountingInputID) { + totals.TotalBaseTokensInStorageDeposit += nft.accountingInput.Amount } } @@ -111,10 +111,10 @@ func (txb *AnchorTransactionBuilder) sumOutputs() *TransactionTotals { if !f.producesAccountingOutput() { continue } - totals.TotalBaseTokensInStorageDeposit += f.out.Amount - id := f.out.MustNativeTokenID() + totals.TotalBaseTokensInStorageDeposit += f.accountingOutput.Amount + id := f.accountingOutput.MustNativeTokenID() totals.TokenCirculatingSupplies[id] = big.NewInt(0) - simpleTokenScheme := util.MustTokenScheme(f.out.TokenScheme) + simpleTokenScheme := util.MustTokenScheme(f.accountingOutput.TokenScheme) totals.TokenCirculatingSupplies[id].Sub(simpleTokenScheme.MintedTokens, simpleTokenScheme.MeltedTokens) } for _, o := range txb.postedOutputs { @@ -131,7 +131,7 @@ func (txb *AnchorTransactionBuilder) sumOutputs() *TransactionTotals { } for _, nft := range txb.nftsIncluded { if !nft.sentOutside { - totals.TotalBaseTokensInStorageDeposit += nft.out.Amount + totals.TotalBaseTokensInStorageDeposit += nft.resultingOutput.Amount } } return totals diff --git a/packages/vm/vmtxbuilder/txbuilder.go b/packages/vm/vmtxbuilder/txbuilder.go index 51f62f3100..bd9c8acfac 100644 --- a/packages/vm/vmtxbuilder/txbuilder.go +++ b/packages/vm/vmtxbuilder/txbuilder.go @@ -95,10 +95,10 @@ func (txb *AnchorTransactionBuilder) Clone() *AnchorTransactionBuilder { } } -// SplitAssetsIntoInternalOutputs splits the native Tokens/NFT from a given (request) output. +// splitAssetsIntoInternalOutputs splits the native Tokens/NFT from a given (request) output. // returns the resulting outputs and the list of new outputs // (some of the native tokens might already have an accounting output owned by the chain, so we don't need new outputs for those) -func (txb *AnchorTransactionBuilder) SplitAssetsIntoInternalOutputs(req isc.OnLedgerRequest) uint64 { +func (txb *AnchorTransactionBuilder) splitAssetsIntoInternalOutputs(req isc.OnLedgerRequest) uint64 { requiredSD := uint64(0) for _, nativeToken := range req.Assets().NativeTokens { // ensure this NT is in the txbuilder, update it @@ -117,27 +117,32 @@ func (txb *AnchorTransactionBuilder) SplitAssetsIntoInternalOutputs(req isc.OnLe if req.NFT() != nil { // create new output nftIncl := txb.internalNFTOutputFromRequest(req.Output().(*iotago.NFTOutput), req.OutputID()) - requiredSD += nftIncl.out.Amount + requiredSD += nftIncl.resultingOutput.Amount } txb.consumed = append(txb.consumed, req) return requiredSD } +func (txb *AnchorTransactionBuilder) assertLimits() { + if txb.InputsAreFull() { + panic(vmexceptions.ErrInputLimitExceeded) + } + if txb.outputsAreFull() { + panic(vmexceptions.ErrOutputLimitExceeded) + } + txb.mustCheckTotalNativeTokensExceeded() +} + // Consume adds an input to the transaction. // It panics if transaction cannot hold that many inputs // All explicitly consumed inputs will hold fixed index in the transaction // It updates total assets held by the chain. So it may panic due to exceed output counts // Returns the amount of baseTokens needed to cover SD costs for the NTs/NFT contained by the request output func (txb *AnchorTransactionBuilder) Consume(req isc.OnLedgerRequest) uint64 { - if txb.InputsAreFull() { - panic(vmexceptions.ErrInputLimitExceeded) - } - - defer txb.mustCheckTotalNativeTokensExceeded() - + defer txb.assertLimits() // deduct the minSD for all the outputs that need to be created - requiredSD := txb.SplitAssetsIntoInternalOutputs(req) + requiredSD := txb.splitAssetsIntoInternalOutputs(req) return requiredSD } @@ -145,31 +150,16 @@ func (txb *AnchorTransactionBuilder) Consume(req isc.OnLedgerRequest) uint64 { // consumes the original request and cretes a new output keeping assets intact // return the position of the resulting output in `txb.postedOutputs` func (txb *AnchorTransactionBuilder) ConsumeUnprocessable(req isc.OnLedgerRequest) int { - if txb.InputsAreFull() { - panic(vmexceptions.ErrInputLimitExceeded) - } - - if txb.outputsAreFull() { - panic(vmexceptions.ErrOutputLimitExceeded) - } - - defer txb.mustCheckTotalNativeTokensExceeded() - + defer txb.assertLimits() txb.consumed = append(txb.consumed, req) - txb.postedOutputs = append(txb.postedOutputs, retryOutputFromOnLedgerRequest(req, txb.anchorOutput.AliasID)) - return len(txb.postedOutputs) - 1 } // AddOutput adds an information about posted request. It will produce output // Return adjustment needed for the L2 ledger (adjustment on base tokens related to storage deposit) func (txb *AnchorTransactionBuilder) AddOutput(o iotago.Output) int64 { - if txb.outputsAreFull() { - panic(vmexceptions.ErrOutputLimitExceeded) - } - - defer txb.mustCheckTotalNativeTokensExceeded() + defer txb.assertLimits() storageDeposit := parameters.L1().Protocol.RentStructure.MinRent(o) if o.Deposit() < storageDeposit { @@ -239,27 +229,27 @@ func (txb *AnchorTransactionBuilder) inputs() (iotago.OutputSet, iotago.OutputID // internal native token outputs for _, nativeTokenBalance := range txb.nativeTokenOutputsSorted() { if nativeTokenBalance.requiresExistingAccountingUTXOAsInput() { - outputID := nativeTokenBalance.accountingoutputID + outputID := nativeTokenBalance.accountingInputID outputIDs = append(outputIDs, outputID) - inputs[outputID] = nativeTokenBalance.in + inputs[outputID] = nativeTokenBalance.accountingInput } } // foundries for _, foundry := range txb.foundriesSorted() { if foundry.requiresExistingAccountingUTXOAsInput() { - outputID := foundry.outputID + outputID := foundry.accountingInputID outputIDs = append(outputIDs, outputID) - inputs[outputID] = foundry.in + inputs[outputID] = foundry.accountingInput } } // nfts for _, nft := range txb.nftsSorted() { - if !isc.IsEmptyOutputID(nft.outputID) { - outputID := nft.outputID + if !isc.IsEmptyOutputID(nft.accountingInputID) { + outputID := nft.accountingInputID outputIDs = append(outputIDs, outputID) - inputs[outputID] = nft.in + inputs[outputID] = nft.accountingInput } } @@ -325,7 +315,7 @@ func (txb *AnchorTransactionBuilder) outputs(stateMetadata []byte) iotago.Output // creating outputs for updated foundries foundriesToBeUpdated, _ := txb.FoundriesToBeUpdated() for _, sn := range foundriesToBeUpdated { - ret = append(ret, txb.invokedFoundries[sn].out) + ret = append(ret, txb.invokedFoundries[sn].accountingOutput) } // creating outputs for new NFTs nftOuts := txb.NFTOutputs() @@ -351,7 +341,7 @@ func (txb *AnchorTransactionBuilder) numInputs() int { } } for _, nft := range txb.nftsIncluded { - if !isc.IsEmptyOutputID(nft.outputID) { + if !isc.IsEmptyOutputID(nft.accountingInputID) { ret++ } } diff --git a/packages/vm/vmtxbuilder/txbuilder_test.go b/packages/vm/vmtxbuilder/txbuilder_test.go index 32dcf45cbb..bb8b6f2379 100644 --- a/packages/vm/vmtxbuilder/txbuilder_test.go +++ b/packages/vm/vmtxbuilder/txbuilder_test.go @@ -271,13 +271,6 @@ func TestTxBuilderConsistency(t *testing.T) { runConsume(txb, nativeTokenIDs, runTimes, testAmount, mockedAccounts) }, vmexceptions.ErrInputLimitExceeded) require.Error(t, err, vmexceptions.ErrInputLimitExceeded) - - essence, _ := txb.BuildTransactionEssence(dummyStateMetadata) - txb.MustBalanced() - - essenceBytes, err := essence.Serialize(serializer.DeSeriModeNoValidation, nil) - require.NoError(t, err) - t.Logf("essence bytes len = %d", len(essenceBytes)) }) t.Run("exceeded outputs", func(t *testing.T) { const runTimesInputs = 120 @@ -295,15 +288,7 @@ func TestTxBuilderConsistency(t *testing.T) { addOutput(txb, 1, nativeTokenIDs[idx], mockedAccounts) } }, vmexceptions.ErrOutputLimitExceeded) - require.Error(t, err, vmexceptions.ErrOutputLimitExceeded) - - essence, _ := txb.BuildTransactionEssence(dummyStateMetadata) - txb.MustBalanced() - - essenceBytes, err := essence.Serialize(serializer.DeSeriModeNoValidation, nil) - require.NoError(t, err) - t.Logf("essence bytes len = %d", len(essenceBytes)) }) t.Run("randomize", func(t *testing.T) { const runTimes = 30 diff --git a/packages/webapi/api.go b/packages/webapi/api.go index 526a672625..b6e4b37e9b 100644 --- a/packages/webapi/api.go +++ b/packages/webapi/api.go @@ -48,7 +48,7 @@ func AddHealthEndpoint(server echoswagger.ApiRoot, chainService interfaces.Chain SetSummary("Returns 200 if the node is healthy.") } -func loadControllers(server echoswagger.ApiRoot, mocker *Mocker, controllersToLoad []interfaces.APIController, authMiddleware func() echo.MiddlewareFunc) { +func loadControllers(server echoswagger.ApiRoot, mocker *Mocker, controllersToLoad []interfaces.APIController, authMiddleware echo.MiddlewareFunc) { for _, controller := range controllersToLoad { group := server.Group(controller.Name(), fmt.Sprintf("/v%d/", APIVersion)) controller.RegisterPublic(group, mocker) @@ -66,7 +66,7 @@ func loadControllers(server echoswagger.ApiRoot, mocker *Mocker, controllersToLo } if authMiddleware != nil { - group.EchoGroup().Use(authMiddleware()) + group.EchoGroup().Use(authMiddleware) } controller.RegisterAdmin(adminGroup, mocker) @@ -110,13 +110,7 @@ func Init( userService := services.NewUserService(userManager) // -- - claimValidator := func(claims *authentication.WaspClaims) bool { - // The v2 api uses another way of permission handling, so we can always return true here. - // Permissions are now validated at the route level. See the webapi/v2/controllers/*/controller.go routes. - return true - } - - authMiddleware := authentication.AddV2Authentication(server, userManager, nodeIdentityProvider, authConfig, claimValidator) + authMiddleware := authentication.AddAuthentication(server, userManager, nodeIdentityProvider, authConfig, mocker) controllersToLoad := []interfaces.APIController{ chain.NewChainController(logger, chainService, committeeService, evmService, nodeService, offLedgerService, registryService), diff --git a/packages/webapi/models/core_blocklog.go b/packages/webapi/models/core_blocklog.go index 59c36e28e5..a76000e7a1 100644 --- a/packages/webapi/models/core_blocklog.go +++ b/packages/webapi/models/core_blocklog.go @@ -30,7 +30,7 @@ func MapBlockInfoResponse(info *blocklog.BlockInfo) *BlockInfoResponse { prevAOStr := "" if info.PreviousAliasOutput != nil { blockindex = info.PreviousAliasOutput.GetAliasOutput().StateIndex + 1 - prevAOStr = string(info.PreviousAliasOutput.Bytes()) + prevAOStr = iotago.EncodeHex(info.PreviousAliasOutput.Bytes()) } return &BlockInfoResponse{ BlockIndex: blockindex, diff --git a/packages/webapi/models/mock/AuthInfoModel.json b/packages/webapi/models/mock/AuthInfoModel.json new file mode 100644 index 0000000000..8aca4cda59 --- /dev/null +++ b/packages/webapi/models/mock/AuthInfoModel.json @@ -0,0 +1,4 @@ +{ + "scheme": "jwt", + "authURL": "/auth" +} \ No newline at end of file diff --git a/packages/webapi/models/mock/LoginRequest.json b/packages/webapi/models/mock/LoginRequest.json new file mode 100644 index 0000000000..74d65e01e8 --- /dev/null +++ b/packages/webapi/models/mock/LoginRequest.json @@ -0,0 +1,4 @@ +{ + "username":"wasp", + "password":"wasp" +} \ No newline at end of file diff --git a/packages/webapi/models/mock/LoginResponse.json b/packages/webapi/models/mock/LoginResponse.json new file mode 100644 index 0000000000..02fe1cf4aa --- /dev/null +++ b/packages/webapi/models/mock/LoginResponse.json @@ -0,0 +1,3 @@ +{ + "jwt": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJ3YXNwIiwic3ViIjoid2FzcCIsImF1ZCI6WyJ3YXNwIl0sImV4cCI6MTY4OTk1MTAyNCwibmJmIjoxNjg5ODY0NjI0LCJpYXQiOjE2ODk4NjQ2MjQsImp0aSI6IjE2ODk4NjQ2MjQiLCJwZXJtaXNzaW9ucyI6eyJ3cml0ZSI6e319fQ.LNUuTaoRjEPQyD2nQ00O6NeadiG7nmOEyVIQmGNb1a0" +} \ No newline at end of file diff --git a/tools/cluster/cluster.go b/tools/cluster/cluster.go index f114bf2d24..2c159e4207 100644 --- a/tools/cluster/cluster.go +++ b/tools/cluster/cluster.go @@ -124,13 +124,35 @@ func (clu *Cluster) AddTrustedNode(peerInfo apiclient.PeeringTrustRequest, onNod return nil } -func (clu *Cluster) TrustAll() error { +func (clu *Cluster) Login() ([]string, error) { + allNodes := clu.Config.AllNodes() + jwtTokens := make([]string, len(allNodes)) + for ni := range allNodes { + res, _, err := clu.WaspClient(allNodes[ni]).AuthApi.Authenticate(context.Background()). + LoginRequest(*apiclient.NewLoginRequest("wasp", "wasp")). + Execute() //nolint:bodyclose // false positive + if err != nil { + return nil, err + } + jwtTokens[ni] = "Bearer " + res.Jwt + } + return jwtTokens, nil +} + +func (clu *Cluster) TrustAll(jwtTokens ...string) error { allNodes := clu.Config.AllNodes() allPeers := make([]*apiclient.PeeringNodeIdentityResponse, len(allNodes)) + clients := make([]*apiclient.APIClient, len(allNodes)) + for ni := range allNodes { + clients[ni] = clu.WaspClient(allNodes[ni]) + if jwtTokens != nil { + clients[ni].GetConfig().AddDefaultHeader("Authorization", jwtTokens[ni]) + } + } for ni := range allNodes { var err error //nolint:bodyclose // false positive - if allPeers[ni], _, err = clu.WaspClient(allNodes[ni]).NodeApi.GetPeeringIdentity(context.Background()).Execute(); err != nil { + if allPeers[ni], _, err = clients[ni].NodeApi.GetPeeringIdentity(context.Background()).Execute(); err != nil { return err } } @@ -140,7 +162,7 @@ func (clu *Cluster) TrustAll() error { if ni == pi { continue // dont trust self } - if _, err = clu.WaspClient(allNodes[ni]).NodeApi.TrustPeer(context.Background()).PeeringTrustRequest( + if _, err = clients[ni].NodeApi.TrustPeer(context.Background()).PeeringTrustRequest( apiclient.PeeringTrustRequest{ Name: fmt.Sprintf("%d", pi), PublicKey: allPeers[pi].PublicKey, @@ -534,11 +556,18 @@ func (clu *Cluster) StartAndTrustAll(dataPath string) error { return fmt.Errorf("data path %s does not exist", dataPath) } - if err := clu.Start(); err != nil { + if err = clu.Start(); err != nil { return err } - if err := clu.TrustAll(); err != nil { + var jwtTokens []string + if clu.Config.Wasp[0].AuthScheme == "jwt" { + if jwtTokens, err = clu.Login(); err != nil { + return err + } + } + + if err := clu.TrustAll(jwtTokens...); err != nil { return err } diff --git a/tools/cluster/config.go b/tools/cluster/config.go index 2d44cfb822..cf57d2732f 100644 --- a/tools/cluster/config.go +++ b/tools/cluster/config.go @@ -29,6 +29,7 @@ func (w *WaspConfig) WaspConfigTemplateParams(i int) templates.WaspConfigParams MetricsPort: w.FirstMetricsPort + i, OffledgerBroadcastUpToNPeers: 10, PruningMinStatesToKeep: 10000, + AuthScheme: "none", } } diff --git a/tools/cluster/templates/waspconfig.go b/tools/cluster/templates/waspconfig.go index fae6698a82..6c05368fb9 100644 --- a/tools/cluster/templates/waspconfig.go +++ b/tools/cluster/templates/waspconfig.go @@ -17,6 +17,7 @@ type WaspConfigParams struct { ValidatorKeyPair *cryptolib.KeyPair ValidatorAddress string // bech32 encoded address of ValidatorKeyPair PruningMinStatesToKeep int + AuthScheme string } var WaspConfig = ` @@ -115,17 +116,9 @@ var WaspConfig = ` "enabled": true, "bindAddress": "0.0.0.0:{{.APIPort}}", "auth": { - "scheme": "none", + "scheme": "{{.AuthScheme}}", "jwt": { "duration": "24h" - }, - "basic": { - "username": "wasp" - }, - "ip": { - "whitelist": [ - "0.0.0.0" - ] } }, "limits": { diff --git a/tools/cluster/tests/wasm/inccounter_bg.wasm b/tools/cluster/tests/wasm/inccounter_bg.wasm index ccfba28257..7d88d2fe99 100644 Binary files a/tools/cluster/tests/wasm/inccounter_bg.wasm and b/tools/cluster/tests/wasm/inccounter_bg.wasm differ diff --git a/tools/cluster/tests/wasp-cli_rotation_test.go b/tools/cluster/tests/wasp-cli_rotation_test.go index 486b976fc6..716b9ea4a4 100644 --- a/tools/cluster/tests/wasp-cli_rotation_test.go +++ b/tools/cluster/tests/wasp-cli_rotation_test.go @@ -207,6 +207,6 @@ func TestRotateOnOrigin(t *testing.T) { w.MustRun("chain", "rotate-with-dkg", "--node=1", "--peers=2,3", "--skip-maintenance") // NOTE: must skip "start/stop maintenance" because node1 isn't part of the committee w.MustRun("chain", "deposit", "base:10000000", "--node=1") // deposit works // assert `rotate-with-dkg` works with maintenance (when the node is part of the initial/final committee) - w.MustRun("chain", "rotate-with-dkg", "--node=1") // NOTE: must skip "start/stop maintenance" because node1 isn't part of the committee + w.MustRun("chain", "rotate-with-dkg", "--node=1") w.MustRun("chain", "deposit", "base:10000000", "--node=1") // deposit works } diff --git a/tools/cluster/tests/wasp-cli_test.go b/tools/cluster/tests/wasp-cli_test.go index 8050058133..cc0eb377cb 100644 --- a/tools/cluster/tests/wasp-cli_test.go +++ b/tools/cluster/tests/wasp-cli_test.go @@ -45,6 +45,21 @@ func TestWaspCLINoChains(t *testing.T) { require.Contains(t, out[0], "Total 0 chain(s)") } +func TestWaspAuth(t *testing.T) { + w := newWaspCLITest(t, waspClusterOpts{ + modifyConfig: func(nodeIndex int, configParams templates.WaspConfigParams) templates.WaspConfigParams { + configParams.AuthScheme = "jwt" + return configParams + }, + }) + _, err := w.Run("chain", "list", "--node=0", "--node=0") + require.Error(t, err) + out := w.MustRun("auth", "login", "--node=0", "-u=wasp", "-p=wasp") + require.Equal(t, "Successfully authenticated", out[1]) + out = w.MustRun("chain", "list", "--node=0", "--node=0") + require.Contains(t, out[0], "Total 0 chain(s)") +} + func TestWaspCLI1Chain(t *testing.T) { w := newWaspCLITest(t) diff --git a/tools/evm/iscutils/.npmignore b/tools/evm/iscutils/.npmignore new file mode 100644 index 0000000000..844b7744c3 --- /dev/null +++ b/tools/evm/iscutils/.npmignore @@ -0,0 +1,2 @@ +** +!*.sol \ No newline at end of file diff --git a/tools/evm/iscutils/README.md b/tools/evm/iscutils/README.md new file mode 100644 index 0000000000..512d335caf --- /dev/null +++ b/tools/evm/iscutils/README.md @@ -0,0 +1,40 @@ +# @iota/iscutils + +The iscutils package contains various utility methods to simplify the interaction with the IOTA Magic contract. This utility library is designed to be used with [@iota/iscmagic](https://www.npmjs.com/package/@iota/iscmagic/) npm package. + +The Magic contract, an EVM contract, is deployed by default on every ISC chain. It has several methods, accessed via different interfaces like ISCSandbox, ISCAccounts, ISCUtil and more. These can be utilized within any Solidity contract by importing the @iota/iscmagic library. + +For further information on the Magic contract, check the [Wiki](https://wiki.iota.org/shimmer/smart-contracts/guide/evm/magic/). + +## Installing @iota/iscutils contracts + +The @iota/iscutils contracts are installable via __NPM__ with + +```bash +npm install @iota/iscutils +``` + +After installing `@iota/iscutils` you can use the functions by importing them as you normally would. + +```solidity +pragma solidity >=0.8.5; + +import "@iota/iscmagic/ISC.sol"; +import "@iota/iscutils/prng.sol"; + +contract MyEVMContract { + using PRNG for PRNG.PRNGState; + + event PseudoRNG(uint256 value); + + PRNG.PRNGState private prngState; + + function emitValue() public { + bytes32 e = ISC.sandbox.getEntropy(); + prngState.seed(e); + uint256 random = prngState.generateRandomNumber(); + emit PseudoRNG(random); + } +} + +``` \ No newline at end of file diff --git a/tools/evm/iscutils/package.json b/tools/evm/iscutils/package.json new file mode 100644 index 0000000000..cab237c1fa --- /dev/null +++ b/tools/evm/iscutils/package.json @@ -0,0 +1,23 @@ +{ + "name": "@iota/iscutils", + "version": "0.0.0", + "description": "The iscutils package contains various utility methods to simplify the interaction with the IOTA Magic contract.", + "repository": { + "type": "git", + "url": "https://github.com/iotaledger/wasp.git", + "directory": "tools/evm/iscutils" + }, + "keywords": [ + "iscutils", + "wasp", + "solidity", + "evm", + "isc" + ], + "author": "Iota Smart Contracts", + "license": "Apache-2.0", + "bugs": { + "url": "https://github.com/iotaledger/wasp/issues" + }, + "homepage": "https://github.com/iotaledger/wasp/blob/develop/tools/evm/iscutils/README.md" +} diff --git a/tools/evm/iscutils/prng.sol b/tools/evm/iscutils/prng.sol new file mode 100644 index 0000000000..e546823cb2 --- /dev/null +++ b/tools/evm/iscutils/prng.sol @@ -0,0 +1,32 @@ +// Copyright 2020 IOTA Stiftung +// SPDX-License-Identifier: MIT +pragma solidity >=0.8.5; + +/// @title Pseudorandom Number Generator (PRNG) Library +/// @notice This library is used to generate pseudorandom numbers +/// @dev Not recommended for generating cryptographic secure randomness +library PRNG { + /// @dev Represents the state of the PRNG + struct PRNGState { + bytes32 state; + } + + /// @notice Generate a new pseudorandom number + /// @dev Takes the current state, hashes it and returns the new state. + /// @param self The PRNGState struct to use and alter the state + /// @return The generated pseudorandom number + function generateRandomNumber(PRNGState storage self) internal returns (uint256) { + require(self.state != bytes32(0), "state must be seeded first"); + self.state = keccak256(abi.encodePacked(self.state)); + return uint256(self.state); + } + + /// @notice Seed the PRNG + /// @dev The seed should not be zero + /// @param self The PRNGState struct to update the state + /// @param entropy The seed value (entropy) + function seed(PRNGState storage self, bytes32 entropy) internal { + require(entropy != bytes32(0), "seed must not be zero"); + self.state = entropy; + } +} \ No newline at end of file diff --git a/tools/wasp-cli/chain/rotate.go b/tools/wasp-cli/chain/rotate.go index 8c8948cf02..513650c582 100644 --- a/tools/wasp-cli/chain/rotate.go +++ b/tools/wasp-cli/chain/rotate.go @@ -75,7 +75,7 @@ func initRotateWithDKGCmd() *cobra.Command { withChainFlag(cmd, &chain) cmd.Flags().IntVarP(&quorum, "quorum", "", 0, "quorum (default: 3/4s of the number of committee nodes)") cmd.Flags().BoolVar(&skipMaintenance, "skip-maintenance", false, "quorum (default: 3/4s of the number of committee nodes)") - cmd.Flags().BoolVarP(&offLedger, "off-ledger", "o", false, + cmd.Flags().BoolVarP(&offLedger, "off-ledger", "o", true, "post an off-ledger request", )