From 4297256ae3169aad49ff46bf7ec73af3e61630c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastian=20M=C3=BCller?= Date: Mon, 13 May 2024 15:21:04 -0700 Subject: [PATCH] allow construction of paths with syntactically invalid identifiers --- runtime/ast/position.go | 4 +++ runtime/interpreter/interpreter.go | 6 ++-- runtime/interpreter/value.go | 24 ++------------ runtime/literal.go | 9 ++--- runtime/sema/check_path_expression.go | 32 ++++++++++-------- runtime/sema/check_path_expression_test.go | 38 ++++++++++++++-------- runtime/stdlib/account_test.go | 10 +++--- runtime/tests/interpreter/path_test.go | 18 ++++++---- 8 files changed, 70 insertions(+), 71 deletions(-) diff --git a/runtime/ast/position.go b/runtime/ast/position.go index 541d097d40..cd21358ac1 100644 --- a/runtime/ast/position.go +++ b/runtime/ast/position.go @@ -146,6 +146,10 @@ func (e Range) EndPosition(common.MemoryGauge) Position { // NewRangeFromPositioned func NewRangeFromPositioned(memoryGauge common.MemoryGauge, hasPosition HasPosition) Range { + if hasPosition == nil { + return EmptyRange + } + return NewRange( memoryGauge, hasPosition.StartPosition(), diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index 13225041ce..65c7b02961 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -3320,21 +3320,21 @@ var ConverterDeclarations = []ValueConverterDeclaration{ name: sema.PublicPathType.Name, functionType: sema.PublicPathConversionFunctionType, convert: func(interpreter *Interpreter, value Value, _ LocationRange) Value { - return ConvertPublicPath(interpreter, value) + return newPathFromStringValue(interpreter, common.PathDomainPublic, value) }, }, { name: sema.PrivatePathType.Name, functionType: sema.PrivatePathConversionFunctionType, convert: func(interpreter *Interpreter, value Value, _ LocationRange) Value { - return ConvertPrivatePath(interpreter, value) + return newPathFromStringValue(interpreter, common.PathDomainPrivate, value) }, }, { name: sema.StoragePathType.Name, functionType: sema.StoragePathConversionFunctionType, convert: func(interpreter *Interpreter, value Value, _ LocationRange) Value { - return ConvertStoragePath(interpreter, value) + return newPathFromStringValue(interpreter, common.PathDomainStorage, value) }, }, } diff --git a/runtime/interpreter/value.go b/runtime/interpreter/value.go index 11b95b2c04..0da0da9d80 100644 --- a/runtime/interpreter/value.go +++ b/runtime/interpreter/value.go @@ -21646,21 +21646,13 @@ func (PathValue) IsStorable() bool { return true } -func convertPath(interpreter *Interpreter, domain common.PathDomain, value Value) Value { +func newPathFromStringValue(interpreter *Interpreter, domain common.PathDomain, value Value) Value { stringValue, ok := value.(*StringValue) if !ok { return Nil } - _, err := sema.CheckPathLiteral( - domain.Identifier(), - stringValue.Str, - ReturnEmptyRange, - ReturnEmptyRange, - ) - if err != nil { - return Nil - } + // NOTE: any identifier is allowed, it does not have to match the syntax for path literals return NewSomeValueNonCopying( interpreter, @@ -21672,18 +21664,6 @@ func convertPath(interpreter *Interpreter, domain common.PathDomain, value Value ) } -func ConvertPublicPath(interpreter *Interpreter, value Value) Value { - return convertPath(interpreter, common.PathDomainPublic, value) -} - -func ConvertPrivatePath(interpreter *Interpreter, value Value) Value { - return convertPath(interpreter, common.PathDomainPrivate, value) -} - -func ConvertStoragePath(interpreter *Interpreter, value Value) Value { - return convertPath(interpreter, common.PathDomainStorage, value) -} - func (v PathValue) Storable( storage atree.SlabStorage, address atree.Address, diff --git a/runtime/literal.go b/runtime/literal.go index 03dda380aa..8e320e1ac3 100644 --- a/runtime/literal.go +++ b/runtime/literal.go @@ -156,14 +156,11 @@ func pathLiteralValue( pathIdentifier := pathExpression.Identifier.Identifier pathType, err := sema.CheckPathLiteral( + memoryGauge, pathDomain, pathIdentifier, - func() ast.Range { - return ast.NewRangeFromPositioned(memoryGauge, pathExpression.Domain) - }, - func() ast.Range { - return ast.NewRangeFromPositioned(memoryGauge, pathExpression.Identifier) - }, + pathExpression.Domain, + pathExpression.Identifier, ) if err != nil { return nil, InvalidLiteralError diff --git a/runtime/sema/check_path_expression.go b/runtime/sema/check_path_expression.go index b58c96feb9..7cfaf5c396 100644 --- a/runtime/sema/check_path_expression.go +++ b/runtime/sema/check_path_expression.go @@ -23,19 +23,17 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" + "github.com/onflow/cadence/runtime/errors" ) func (checker *Checker) VisitPathExpression(expression *ast.PathExpression) Type { ty, err := CheckPathLiteral( + checker.memoryGauge, expression.Domain.Identifier, expression.Identifier.Identifier, - func() ast.Range { - return ast.NewRangeFromPositioned(checker.memoryGauge, expression.Domain) - }, - func() ast.Range { - return ast.NewRangeFromPositioned(checker.memoryGauge, expression.Identifier) - }, + expression.Domain, + expression.Identifier, ) checker.report(err) @@ -45,14 +43,20 @@ func (checker *Checker) VisitPathExpression(expression *ast.PathExpression) Type var isValidIdentifier = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`).MatchString -func CheckPathLiteral(domainString, identifier string, domainRangeThunk, idRangeThunk func() ast.Range) (Type, error) { +func CheckPathLiteral( + gauge common.MemoryGauge, + domain string, + identifier string, + domainRange ast.HasPosition, + identifierRange ast.HasPosition, +) (Type, error) { // Check that the domain is valid - domain := common.PathDomainFromIdentifier(domainString) - if domain == common.PathDomainUnknown { + pathDomain := common.PathDomainFromIdentifier(domain) + if pathDomain == common.PathDomainUnknown { return PathType, &InvalidPathDomainError{ - ActualDomain: domainString, - Range: domainRangeThunk(), + ActualDomain: domain, + Range: ast.NewRangeFromPositioned(gauge, domainRange), } } @@ -60,11 +64,11 @@ func CheckPathLiteral(domainString, identifier string, domainRangeThunk, idRange if !isValidIdentifier(identifier) { return PathType, &InvalidPathIdentifierError{ ActualIdentifier: identifier, - Range: idRangeThunk(), + Range: ast.NewRangeFromPositioned(gauge, identifierRange), } } - switch domain { + switch pathDomain { case common.PathDomainStorage: return StoragePathType, nil case common.PathDomainPublic: @@ -72,6 +76,6 @@ func CheckPathLiteral(domainString, identifier string, domainRangeThunk, idRange case common.PathDomainPrivate: return PrivatePathType, nil default: - return PathType, nil + panic(errors.NewUnreachableError()) } } diff --git a/runtime/sema/check_path_expression_test.go b/runtime/sema/check_path_expression_test.go index 1c94aa370f..b4bef47f55 100644 --- a/runtime/sema/check_path_expression_test.go +++ b/runtime/sema/check_path_expression_test.go @@ -23,62 +23,72 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/onflow/cadence/runtime/ast" ) func TestCheckPathLiteral(t *testing.T) { t.Parallel() - rangeThunk := func() ast.Range { - return ast.EmptyRange - } - t.Run("valid domain (storage), valid identifier", func(t *testing.T) { - ty, err := CheckPathLiteral("storage", "test", rangeThunk, rangeThunk) + t.Parallel() + + ty, err := CheckPathLiteral(nil, "storage", "test", nil, nil) require.NoError(t, err) assert.Equal(t, StoragePathType, ty) }) t.Run("valid domain (private), valid identifier", func(t *testing.T) { - ty, err := CheckPathLiteral("private", "test", rangeThunk, rangeThunk) + t.Parallel() + + ty, err := CheckPathLiteral(nil, "private", "test", nil, nil) require.NoError(t, err) assert.Equal(t, PrivatePathType, ty) }) t.Run("valid domain (public), valid identifier", func(t *testing.T) { - ty, err := CheckPathLiteral("public", "test", rangeThunk, rangeThunk) + t.Parallel() + + ty, err := CheckPathLiteral(nil, "public", "test", nil, nil) require.NoError(t, err) assert.Equal(t, PublicPathType, ty) }) t.Run("invalid domain (empty), valid identifier", func(t *testing.T) { - _, err := CheckPathLiteral("", "test", rangeThunk, rangeThunk) + t.Parallel() + + _, err := CheckPathLiteral(nil, "", "test", nil, nil) var invalidPathDomainError *InvalidPathDomainError require.ErrorAs(t, err, &invalidPathDomainError) }) t.Run("invalid domain (foo), valid identifier", func(t *testing.T) { - _, err := CheckPathLiteral("foo", "test", rangeThunk, rangeThunk) + t.Parallel() + + _, err := CheckPathLiteral(nil, "foo", "test", nil, nil) var invalidPathDomainError *InvalidPathDomainError require.ErrorAs(t, err, &invalidPathDomainError) }) t.Run("valid domain (public), invalid identifier (empty)", func(t *testing.T) { - _, err := CheckPathLiteral("public", "", rangeThunk, rangeThunk) + t.Parallel() + + _, err := CheckPathLiteral(nil, "public", "", nil, nil) var invalidPathIdentifierError *InvalidPathIdentifierError require.ErrorAs(t, err, &invalidPathIdentifierError) }) t.Run("valid domain (public), invalid identifier ($)", func(t *testing.T) { - _, err := CheckPathLiteral("public", "$", rangeThunk, rangeThunk) + t.Parallel() + + _, err := CheckPathLiteral(nil, "public", "$", nil, nil) var invalidPathIdentifierError *InvalidPathIdentifierError require.ErrorAs(t, err, &invalidPathIdentifierError) }) t.Run("valid domain (public), invalid identifier (0)", func(t *testing.T) { - _, err := CheckPathLiteral("public", "0", rangeThunk, rangeThunk) + t.Parallel() + + _, err := CheckPathLiteral(nil, "public", "0", nil, nil) var invalidPathIdentifierError *InvalidPathIdentifierError require.ErrorAs(t, err, &invalidPathIdentifierError) }) diff --git a/runtime/stdlib/account_test.go b/runtime/stdlib/account_test.go index 99e302036d..cbff8ee3a0 100644 --- a/runtime/stdlib/account_test.go +++ b/runtime/stdlib/account_test.go @@ -23,7 +23,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/sema" ) @@ -31,10 +30,6 @@ func TestSemaCheckPathLiteralForInternalStorageDomains(t *testing.T) { t.Parallel() - rangeThunk := func() ast.Range { - return ast.EmptyRange - } - internalStorageDomains := []string{ InboxStorageDomain, AccountCapabilityStorageDomain, @@ -44,8 +39,11 @@ func TestSemaCheckPathLiteralForInternalStorageDomains(t *testing.T) { } test := func(domain string) { + t.Run(domain, func(t *testing.T) { - _, err := sema.CheckPathLiteral(domain, "test", rangeThunk, rangeThunk) + t.Parallel() + + _, err := sema.CheckPathLiteral(nil, domain, "test", nil, nil) var invalidPathDomainError *sema.InvalidPathDomainError require.ErrorAs(t, err, &invalidPathDomainError) }) diff --git a/runtime/tests/interpreter/path_test.go b/runtime/tests/interpreter/path_test.go index ca3038b86e..d1a05ac507 100644 --- a/runtime/tests/interpreter/path_test.go +++ b/runtime/tests/interpreter/path_test.go @@ -95,7 +95,7 @@ func TestInterpretConvertStringToPath(t *testing.T) { ) }) - t.Run(fmt.Sprintf("invalid identifier 2: %s", domain.Identifier()), func(t *testing.T) { + t.Run(fmt.Sprintf("syntactically invalid identifier 2: %s", domain.Identifier()), func(t *testing.T) { t.Parallel() @@ -104,19 +104,22 @@ func TestInterpretConvertStringToPath(t *testing.T) { inter := parseCheckAndInterpret(t, fmt.Sprintf( ` - let x = %[1]s(identifier: "2") + let x = %[1]s(identifier: "2")! `, domainType.String(), ), ) assert.Equal(t, - interpreter.Nil, + interpreter.PathValue{ + Domain: domain, + Identifier: "2", + }, inter.Globals.Get("x").GetValue(inter), ) }) - t.Run(fmt.Sprintf("invalid identifier -: %s", domain.Identifier()), func(t *testing.T) { + t.Run(fmt.Sprintf("syntactically invalid identifier -: %s", domain.Identifier()), func(t *testing.T) { t.Parallel() @@ -125,14 +128,17 @@ func TestInterpretConvertStringToPath(t *testing.T) { inter := parseCheckAndInterpret(t, fmt.Sprintf( ` - let x = %[1]s(identifier: "fo-o") + let x = %[1]s(identifier: "fo-o")! `, domainType.String(), ), ) assert.Equal(t, - interpreter.Nil, + interpreter.PathValue{ + Domain: domain, + Identifier: "fo-o", + }, inter.Globals.Get("x").GetValue(inter), ) })