Skip to content

Commit

Permalink
Merge pull request #3449 from onflow/bastian/3080-improve-capabilitiy…
Browse files Browse the repository at this point in the history
…-borrowing
  • Loading branch information
turbolent authored Jul 8, 2024
2 parents 76bd6db + 7833360 commit 2a5e3c4
Show file tree
Hide file tree
Showing 5 changed files with 430 additions and 28 deletions.
137 changes: 124 additions & 13 deletions runtime/capabilities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,20 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
entitlement X
access(all)
resource R {
resource interface RI {}
access(all)
resource R: RI {
access(all)
let foo: Int
access(X)
let bar: Int
init() {
self.foo = 42
self.bar = 21
}
}
Expand Down Expand Up @@ -128,6 +135,12 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
let noCap = self.account.capabilities.storage.issue<&R>(/storage/nonExistentTarget)
self.account.capabilities.publish(noCap, at: /public/nonExistentTarget)
let unentitledRICap = self.account.capabilities.storage.issue<&{RI}>(/storage/r)
self.account.capabilities.publish(unentitledRICap, at: /public/unentitledRI)
let entitledRICap = self.account.capabilities.storage.issue<auth(X) &{RI}>(/storage/r)
self.account.capabilities.publish(entitledRICap, at: /public/entitledRI)
}
access(all)
Expand Down Expand Up @@ -252,15 +265,51 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
access(all)
fun testSwap(): Int {
let ref = self.account.capabilities.get<&R>(/public/r).borrow()!
let ref = self.account.capabilities.get<&R>(/public/r).borrow()!
let r <- self.account.storage.load<@R>(from: /storage/r)
destroy r
let r <- self.account.storage.load<@R>(from: /storage/r)
destroy r
let r2 <- create R2()
self.account.storage.save(<-r2, to: /storage/r)
let r2 <- create R2()
self.account.storage.save(<-r2, to: /storage/r)
return ref.foo
return ref.foo
}
access(all)
fun testRI() {
// Borrow /public/unentitledRI.
// - All unentitled borrows should succeed (as &{RI} / as &R)
// - All entitled borrows should fail (as &{RI} / as &R)
let unentitledRI1 = self.account.capabilities.get<&{RI}>(/public/unentitledRI).borrow()
assert(unentitledRI1 != nil, message: "unentitledRI1 should not be nil")
let entitledRI1 = self.account.capabilities.get<auth(X) &{RI}>(/public/unentitledRI).borrow()
assert(entitledRI1 == nil, message: "entitledRI1 should be nil")
let unentitledR1 = self.account.capabilities.get<&R>(/public/unentitledRI).borrow()
assert(unentitledR1 != nil, message: "unentitledR1 should not be nil")
let entitledR1 = self.account.capabilities.get<auth(X) &R>(/public/unentitledRI).borrow()
assert(entitledR1 == nil, message: "entitledR1 should be nil")
// Borrow /public/entitledRI.
// All borrows should succeed:
// - As &{RI} / as &R
// - Unentitled / entitled
let unentitledRI2 = self.account.capabilities.get<&{RI}>(/public/entitledRI).borrow()
assert(unentitledRI2 != nil, message: "unentitledRI2 should not be nil")
let entitledRI2 = self.account.capabilities.get<auth(X) &{RI}>(/public/entitledRI).borrow()
assert(entitledRI2 != nil, message: "entitledRI2 should not be nil")
let unentitledR2 = self.account.capabilities.get<&R>(/public/entitledRI).borrow()
assert(unentitledR2 != nil, message: "unentitledR2 should not be nil")
let entitledR2 = self.account.capabilities.get<auth(X) &R>(/public/entitledRI).borrow()
assert(entitledR2 != nil, message: "entitledR2 should not be nil")
}
}
`
Expand Down Expand Up @@ -327,6 +376,12 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {

require.ErrorAs(t, err, &interpreter.DereferenceError{})
})

t.Run("testRI", func(t *testing.T) {

_, err := invoke("testRI")
require.NoError(t, err)
})
})

t.Run("struct", func(t *testing.T) {
Expand All @@ -347,13 +402,20 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
entitlement X
access(all)
struct S {
struct interface SI {}
access(all)
struct S: SI {
access(all)
let foo: Int
access(X)
let bar: Int
init() {
self.foo = 42
self.bar = 21
}
}
Expand Down Expand Up @@ -395,6 +457,12 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
let noCap = self.account.capabilities.storage.issue<&S>(/storage/nonExistentTarget)
self.account.capabilities.publish(noCap, at: /public/nonExistentTarget)
let unentitledSICap = self.account.capabilities.storage.issue<&{SI}>(/storage/s)
self.account.capabilities.publish(unentitledSICap, at: /public/unentitledSI)
let entitledSICap = self.account.capabilities.storage.issue<auth(X) &{SI}>(/storage/s)
self.account.capabilities.publish(entitledSICap, at: /public/entitledSI)
}
access(all)
Expand Down Expand Up @@ -519,14 +587,51 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {
access(all)
fun testSwap(): Int {
let ref = self.account.capabilities.get<&S>(/public/s).borrow()!
let ref = self.account.capabilities.get<&S>(/public/s).borrow()!
self.account.storage.load<S>(from: /storage/s)
let s2 = S2()
self.account.storage.save(s2, to: /storage/s)
return ref.foo
}
access(all)
fun testSI() {
self.account.storage.load<S>(from: /storage/s)
// Borrow /public/unentitledSI.
// - All unentitled borrows should succeed (as &{SI} / as &S)
// - All entitled borrows should fail (as &{SI} / as &S)
let s2 = S2()
self.account.storage.save(s2, to: /storage/s)
let unentitledSI1 = self.account.capabilities.get<&{SI}>(/public/unentitledSI).borrow()
assert(unentitledSI1 != nil, message: "unentitledSI1 should not be nil")
return ref.foo
let entitledSI1 = self.account.capabilities.get<auth(X) &{SI}>(/public/unentitledSI).borrow()
assert(entitledSI1 == nil, message: "entitledSI1 should be nil")
let unentitledS1 = self.account.capabilities.get<&S>(/public/unentitledSI).borrow()
assert(unentitledS1 != nil, message: "unentitledS1 should not be nil")
let entitledS1 = self.account.capabilities.get<auth(X) &S>(/public/unentitledSI).borrow()
assert(entitledS1 == nil, message: "entitledS1 should be nil")
// Borrow /public/entitledSI.
// All borrows should succeed:
// - As &{SI} / as &S
// - Unentitled / entitled
let unentitledSI2 = self.account.capabilities.get<&{SI}>(/public/entitledSI).borrow()
assert(unentitledSI2 != nil, message: "unentitledSI2 should not be nil")
let entitledSI2 = self.account.capabilities.get<auth(X) &{SI}>(/public/entitledSI).borrow()
assert(entitledSI2 != nil, message: "entitledSI2 should not be nil")
let unentitledS2 = self.account.capabilities.get<&S>(/public/entitledSI).borrow()
assert(unentitledS2 != nil, message: "unentitledS2 should not be nil")
let entitledS2 = self.account.capabilities.get<auth(X) &S>(/public/entitledSI).borrow()
assert(entitledS2 != nil, message: "entitledS2 should not be nil")
}
}
`
Expand Down Expand Up @@ -593,6 +698,12 @@ func TestRuntimeCapability_borrowAndCheck(t *testing.T) {

require.ErrorAs(t, err, &interpreter.DereferenceError{})
})

t.Run("testSI", func(t *testing.T) {

_, err := invoke("testSI")
require.NoError(t, err)
})
})

t.Run("account", func(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions runtime/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ func (interpreter *Interpreter) VisitEnumCaseDeclaration(_ *ast.EnumCaseDeclarat
panic(errors.NewUnreachableError())
}

func (interpreter *Interpreter) substituteMappedEntitlements(ty sema.Type) sema.Type {
func (interpreter *Interpreter) SubstituteMappedEntitlements(ty sema.Type) sema.Type {
if interpreter.SharedState.currentEntitlementMappedValue == nil {
return ty
}
Expand Down Expand Up @@ -1830,7 +1830,7 @@ func (interpreter *Interpreter) transferAndConvert(
nil,
)

targetType = interpreter.substituteMappedEntitlements(targetType)
targetType = interpreter.SubstituteMappedEntitlements(targetType)

result := interpreter.ConvertAndBox(
locationRange,
Expand Down
4 changes: 2 additions & 2 deletions runtime/interpreter/interpreter_expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx
}

castingExpressionTypes := interpreter.Program.Elaboration.CastingExpressionTypes(expression)
expectedType := interpreter.substituteMappedEntitlements(castingExpressionTypes.TargetType)
expectedType := interpreter.SubstituteMappedEntitlements(castingExpressionTypes.TargetType)

switch expression.Operation {
case ast.OperationFailableCast, ast.OperationForceCast:
Expand All @@ -1325,7 +1325,7 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx
// thus this is the only place where it becomes necessary to "instantiate" the result of a map to its
// concrete outputs. In other places (e.g. interface conformance checks) we want to leave maps generic,
// so we don't substitute them.
valueSemaType := interpreter.substituteMappedEntitlements(interpreter.MustSemaTypeOfValue(value))
valueSemaType := interpreter.SubstituteMappedEntitlements(interpreter.MustSemaTypeOfValue(value))
valueStaticType := ConvertSemaToStaticType(interpreter, valueSemaType)
isSubType := interpreter.IsSubTypeOfSemaType(valueStaticType, expectedType)

Expand Down
36 changes: 25 additions & 11 deletions runtime/stdlib/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -3509,6 +3509,25 @@ func newAccountCapabilitiesUnpublishFunction(
}
}

func canBorrow(
wantedBorrowType *sema.ReferenceType,
capabilityBorrowType *sema.ReferenceType,
) bool {

// Ensure the wanted borrow type is not more permissive than the capability borrow type

if !wantedBorrowType.Authorization.
PermitsAccess(capabilityBorrowType.Authorization) {

return false
}

// Ensure the wanted borrow type is a subtype or supertype of the capability borrow type

return sema.IsSubType(wantedBorrowType.Type, capabilityBorrowType.Type) ||
sema.IsSubType(capabilityBorrowType.Type, wantedBorrowType.Type)
}

func getCheckedCapabilityController(
inter *interpreter.Interpreter,
capabilityAddressValue interpreter.AddressValue,
Expand All @@ -3519,15 +3538,14 @@ func getCheckedCapabilityController(
interpreter.CapabilityControllerValue,
*sema.ReferenceType,
) {

if wantedBorrowType == nil {
wantedBorrowType = capabilityBorrowType
} else if !sema.IsSubType(capabilityBorrowType, wantedBorrowType) {
// Ensure wanted borrow type is not more permissive
// than the capability's borrow type:
// The wanted type must be a supertype
} else {
wantedBorrowType = inter.SubstituteMappedEntitlements(wantedBorrowType).(*sema.ReferenceType)

return nil, nil
if !canBorrow(wantedBorrowType, capabilityBorrowType) {
return nil, nil
}
}

capabilityAddress := capabilityAddressValue.ToAddress()
Expand All @@ -3538,10 +3556,6 @@ func getCheckedCapabilityController(
return nil, nil
}

// Ensure wanted borrow type is not more permissive
// than the controller's borrow type:
// The wanted type must be a supertype

controllerBorrowStaticType := controller.CapabilityControllerBorrowType()

controllerBorrowType, ok :=
Expand All @@ -3550,7 +3564,7 @@ func getCheckedCapabilityController(
panic(errors.NewUnreachableError())
}

if !sema.IsSubType(controllerBorrowType, wantedBorrowType) {
if !canBorrow(wantedBorrowType, controllerBorrowType) {
return nil, nil
}

Expand Down
Loading

0 comments on commit 2a5e3c4

Please sign in to comment.