diff --git a/depinject/config.go b/depinject/config.go index 947539c4b34f..78a656de6d9f 100644 --- a/depinject/config.go +++ b/depinject/config.go @@ -11,6 +11,12 @@ type Config interface { apply(*container) error } +var fallbackValues []interface{} + +func AddFallbackSupply(values ...interface{}) { + fallbackValues = append(fallbackValues, values...) +} + // Provide defines a container configuration which registers the provided dependency // injection providers. Each provider will be called at most once with the // exception of module-scoped providers which are called at most once per module @@ -149,7 +155,7 @@ func Supply(values ...interface{}) Config { loc := LocationFromCaller(1) return containerConfig(func(ctr *container) error { for _, v := range values { - err := ctr.supply(reflect.ValueOf(v), loc) + err := ctr.supply(reflect.ValueOf(v), loc, false) if err != nil { return errors.WithStack(err) } @@ -158,6 +164,23 @@ func Supply(values ...interface{}) Config { }) } +// FallbackSupply works like Supply except that anything supplied with it will +// only be used when no other value of the same type is available, meaning that +// it can be overridden by another Supply. +func fallbackSupply() Config { + loc := LocationFromCaller(1) + return containerConfig(func(ctr *container) error { + for _, v := range fallbackValues { + err := ctr.supply(reflect.ValueOf(v), loc, true) + if err != nil { + // ignore errors coming from default configs + ctr.logf("Ignoring error from fallback config: %s", errors.WithStack(err)) + } + } + return nil + }) +} + // Error defines configuration which causes the dependency injection container to // fail immediately. func Error(err error) Config { diff --git a/depinject/container.go b/depinject/container.go index 945d42c3130b..250f524082f7 100644 --- a/depinject/container.go +++ b/depinject/container.go @@ -153,19 +153,26 @@ func (c *container) getResolver(typ reflect.Type, key *moduleKey) (resolver, err if !found && typ.Kind() == reflect.Interface { matches := map[reflect.Type]reflect.Type{} - var resolverType reflect.Type + defaultMatches := map[reflect.Type]reflect.Type{} for _, r := range c.resolvers { if r.getType().Kind() != reflect.Interface && r.getType().Implements(typ) { - resolverType = r.getType() - matches[resolverType] = resolverType + if defRes, ok := r.(fallbackResolver); ok && defRes.isFallback() { + defaultMatches[r.getType()] = r.getType() + } else { + matches[r.getType()] = r.getType() + } } } if len(matches) == 1 { - res, _ = c.resolverByType(resolverType) - c.logf("Implicitly registering resolver %v for interface type %v", resolverType, typ) + res, _ = c.resolverByType(getResolverType(matches)) + c.logf("Implicitly registering resolver %v for interface type %v", res.getType(), typ) + c.addResolver(typ, res) + } else if len(matches) == 0 && len(defaultMatches) == 1 { + res, _ = c.resolverByType(getResolverType(defaultMatches)) + c.logf("Implicitly registering resolver %v for interface type %v (using the fallback value)", res.getType(), typ) c.addResolver(typ, res) - } else if len(matches) > 1 { + } else if len(matches) > 1 || len(defaultMatches) > 1 { return nil, newErrMultipleImplicitInterfaceBindings(typ, matches) } } @@ -173,6 +180,14 @@ func (c *container) getResolver(typ reflect.Type, key *moduleKey) (resolver, err return res, nil } +func getResolverType(matches map[reflect.Type]reflect.Type) reflect.Type { + for _, v := range matches { + return v + } + + return nil +} + func (c *container) getExplicitResolver(typ reflect.Type, key *moduleKey) (resolver, error) { var pref interfaceBinding var found bool @@ -339,7 +354,7 @@ func (c *container) addNode(provider *providerDescriptor, key *moduleKey) (inter return node, nil } -func (c *container) supply(value reflect.Value, location Location) error { +func (c *container) supply(value reflect.Value, location Location, isFallback bool) error { typ := value.Type() locGrapNode := c.locationGraphNode(location, nil) markGraphNodeAsUsed(locGrapNode) @@ -355,6 +370,7 @@ func (c *container) supply(value reflect.Value, location Location) error { value: value, loc: location, graphNode: typeGraphNode, + fallback: isFallback, }) return nil diff --git a/depinject/container_test.go b/depinject/container_test.go index 43010004cc35..f268471413b3 100644 --- a/depinject/container_test.go +++ b/depinject/container_test.go @@ -791,3 +791,54 @@ func TestFuncTypes(t *testing.T) { require.True(t, ok) require.NoError(t, err) } + +func TestDefaultBindings(t *testing.T) { + smallDuck := smallMallard{} + // this can be used in an init() function + depinject.AddFallbackSupply( + NewLogger(true), + smallDuck, + func() Duck { return smallDuck }, + ) + + var logger Logger + err := depinject.Inject( + depinject.Configs( + depinject.Supply(smallDuck), + ), + &logger) + require.NoError(t, err) + + _, ok := logger.(NoOpLogger) + require.True(t, ok) + + err = depinject.Inject( + depinject.Configs( + depinject.Supply(NewLogger(false)), + depinject.Supply(smallDuck), + ), + &logger) + require.NoError(t, err) + + _, ok = logger.(LoggerImpl) + require.True(t, ok) +} + +type Logger interface { + Log(string) +} + +type LoggerImpl struct{} + +func (LoggerImpl) Log(string) {} + +type NoOpLogger struct{} + +func (NoOpLogger) Log(string) {} + +func NewLogger(isNoOp bool) Logger { + if isNoOp { + return NoOpLogger{} + } + return LoggerImpl{} +} diff --git a/depinject/inject.go b/depinject/inject.go index 609845781a7b..2a9f6a80a9e0 100644 --- a/depinject/inject.go +++ b/depinject/inject.go @@ -16,6 +16,10 @@ package depinject // Use InjectDebug to configure debug behavior. func Inject(containerConfig Config, outputs ...interface{}) error { loc := LocationFromCaller(1) + containerConfig = Configs( + containerConfig, + fallbackSupply(), + ) return inject(loc, AutoDebug(), containerConfig, outputs...) } diff --git a/depinject/resolver.go b/depinject/resolver.go index 352b7bf11e8e..893f29ab3497 100644 --- a/depinject/resolver.go +++ b/depinject/resolver.go @@ -13,3 +13,8 @@ type resolver interface { typeGraphNode() *graphviz.Node getType() reflect.Type } + +type fallbackResolver interface { + resolver + isFallback() bool +} diff --git a/depinject/supply.go b/depinject/supply.go index 4e3a8bf9f92d..46e43c39f495 100644 --- a/depinject/supply.go +++ b/depinject/supply.go @@ -11,6 +11,7 @@ type supplyResolver struct { value reflect.Value loc Location graphNode *graphviz.Node + fallback bool } func (s supplyResolver) getType() reflect.Type { @@ -33,3 +34,7 @@ func (s supplyResolver) resolve(c *container, _ *moduleKey, caller Location) (re func (s supplyResolver) typeGraphNode() *graphviz.Node { return s.graphNode } + +func (s supplyResolver) isFallback() bool { + return s.fallback +}