Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(depinject): fallback Supply #17046

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion depinject/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ type Config interface {
apply(*container) error
}

var fallbackValues []interface{}

func AddFallbackSupply(values ...interface{}) {
fallbackValues = append(fallbackValues, values...)
}

// Provide defines a container configuration which registers the provided dependency
// injection providers. Each provider will be called at most once with the
// exception of module-scoped providers which are called at most once per module
Expand Down Expand Up @@ -149,7 +155,7 @@ func Supply(values ...interface{}) Config {
loc := LocationFromCaller(1)
return containerConfig(func(ctr *container) error {
for _, v := range values {
err := ctr.supply(reflect.ValueOf(v), loc)
err := ctr.supply(reflect.ValueOf(v), loc, false)
if err != nil {
return errors.WithStack(err)
}
Expand All @@ -158,6 +164,23 @@ func Supply(values ...interface{}) Config {
})
}

// FallbackSupply works like Supply except that anything supplied with it will
// only be used when no other value of the same type is available, meaning that
// it can be overridden by another Supply.
func fallbackSupply() Config {
loc := LocationFromCaller(1)
return containerConfig(func(ctr *container) error {
for _, v := range fallbackValues {
err := ctr.supply(reflect.ValueOf(v), loc, true)
if err != nil {
// ignore errors coming from default configs
ctr.logf("Ignoring error from fallback config: %s", errors.WithStack(err))
}
}
return nil
})
}

// Error defines configuration which causes the dependency injection container to
// fail immediately.
func Error(err error) Config {
Expand Down
30 changes: 23 additions & 7 deletions depinject/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,26 +153,41 @@ func (c *container) getResolver(typ reflect.Type, key *moduleKey) (resolver, err

if !found && typ.Kind() == reflect.Interface {
matches := map[reflect.Type]reflect.Type{}
var resolverType reflect.Type
defaultMatches := map[reflect.Type]reflect.Type{}
for _, r := range c.resolvers {
if r.getType().Kind() != reflect.Interface && r.getType().Implements(typ) {
resolverType = r.getType()
matches[resolverType] = resolverType
if defRes, ok := r.(fallbackResolver); ok && defRes.isFallback() {
defaultMatches[r.getType()] = r.getType()
} else {
matches[r.getType()] = r.getType()
}
}
}

if len(matches) == 1 {
res, _ = c.resolverByType(resolverType)
c.logf("Implicitly registering resolver %v for interface type %v", resolverType, typ)
res, _ = c.resolverByType(getResolverType(matches))
c.logf("Implicitly registering resolver %v for interface type %v", res.getType(), typ)
c.addResolver(typ, res)
} else if len(matches) == 0 && len(defaultMatches) == 1 {
res, _ = c.resolverByType(getResolverType(defaultMatches))
c.logf("Implicitly registering resolver %v for interface type %v (using the fallback value)", res.getType(), typ)
c.addResolver(typ, res)
} else if len(matches) > 1 {
} else if len(matches) > 1 || len(defaultMatches) > 1 {
return nil, newErrMultipleImplicitInterfaceBindings(typ, matches)
}
}

return res, nil
}

func getResolverType(matches map[reflect.Type]reflect.Type) reflect.Type {
for _, v := range matches {
return v
}

return nil
}

func (c *container) getExplicitResolver(typ reflect.Type, key *moduleKey) (resolver, error) {
var pref interfaceBinding
var found bool
Expand Down Expand Up @@ -339,7 +354,7 @@ func (c *container) addNode(provider *providerDescriptor, key *moduleKey) (inter
return node, nil
}

func (c *container) supply(value reflect.Value, location Location) error {
func (c *container) supply(value reflect.Value, location Location, isFallback bool) error {
typ := value.Type()
locGrapNode := c.locationGraphNode(location, nil)
markGraphNodeAsUsed(locGrapNode)
Expand All @@ -355,6 +370,7 @@ func (c *container) supply(value reflect.Value, location Location) error {
value: value,
loc: location,
graphNode: typeGraphNode,
fallback: isFallback,
})

return nil
Expand Down
51 changes: 51 additions & 0 deletions depinject/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -791,3 +791,54 @@ func TestFuncTypes(t *testing.T) {
require.True(t, ok)
require.NoError(t, err)
}

func TestDefaultBindings(t *testing.T) {
smallDuck := smallMallard{}
// this can be used in an init() function
depinject.AddFallbackSupply(
NewLogger(true),
smallDuck,
func() Duck { return smallDuck },
)

var logger Logger
err := depinject.Inject(
depinject.Configs(
depinject.Supply(smallDuck),
),
&logger)
require.NoError(t, err)

_, ok := logger.(NoOpLogger)
require.True(t, ok)

err = depinject.Inject(
depinject.Configs(
depinject.Supply(NewLogger(false)),
depinject.Supply(smallDuck),
),
&logger)
require.NoError(t, err)

_, ok = logger.(LoggerImpl)
require.True(t, ok)
}

type Logger interface {
Log(string)
}

type LoggerImpl struct{}

func (LoggerImpl) Log(string) {}

type NoOpLogger struct{}

func (NoOpLogger) Log(string) {}

func NewLogger(isNoOp bool) Logger {
if isNoOp {
return NoOpLogger{}
}
return LoggerImpl{}
}
4 changes: 4 additions & 0 deletions depinject/inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ package depinject
// Use InjectDebug to configure debug behavior.
func Inject(containerConfig Config, outputs ...interface{}) error {
loc := LocationFromCaller(1)
containerConfig = Configs(
containerConfig,
fallbackSupply(),
)
return inject(loc, AutoDebug(), containerConfig, outputs...)
}

Expand Down
5 changes: 5 additions & 0 deletions depinject/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@ type resolver interface {
typeGraphNode() *graphviz.Node
getType() reflect.Type
}

type fallbackResolver interface {
resolver
isFallback() bool
}
5 changes: 5 additions & 0 deletions depinject/supply.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type supplyResolver struct {
value reflect.Value
loc Location
graphNode *graphviz.Node
fallback bool
}

func (s supplyResolver) getType() reflect.Type {
Expand All @@ -33,3 +34,7 @@ func (s supplyResolver) resolve(c *container, _ *moduleKey, caller Location) (re
func (s supplyResolver) typeGraphNode() *graphviz.Node {
return s.graphNode
}

func (s supplyResolver) isFallback() bool {
return s.fallback
}