Skip to content

Commit

Permalink
Add support for sticky termininator selection on dials. Fixes #2019
Browse files Browse the repository at this point in the history
  • Loading branch information
plorenz committed May 7, 2024
1 parent 611f64c commit 3ef7eb2
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 51 deletions.
2 changes: 2 additions & 0 deletions common/ctrl_msg/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ const (
InitiatorLocalAddressHeader = 1112
InitiatorRemoteAddressHeader = 1113

XtStickinessToken

ErrorTypeGeneric = 0
ErrorTypeInvalidTerminator = 1
ErrorTypeMisconfiguredTerminator = 2
Expand Down
2 changes: 2 additions & 0 deletions controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/openziti/ziti/controller/event"
"github.com/openziti/ziti/controller/events"
"github.com/openziti/ziti/controller/handler_peer_ctrl"
"github.com/openziti/ziti/controller/xt_sticky"
"github.com/openziti/ziti/controller/zac"
"math/big"
"os"
Expand Down Expand Up @@ -446,6 +447,7 @@ func (c *Controller) registerXts() {
xt.GlobalRegistry().RegisterFactory(xt_smartrouting.NewFactory())
xt.GlobalRegistry().RegisterFactory(xt_random.NewFactory())
xt.GlobalRegistry().RegisterFactory(xt_weighted.NewFactory())
xt.GlobalRegistry().RegisterFactory(xt_sticky.NewFactory())
}

func (c *Controller) registerComponents() error {
Expand Down
29 changes: 16 additions & 13 deletions controller/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,6 @@ func (network *Network) RerouteLink(l *Link) {
}

func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, error) {
srcR := params.GetSourceRouter()
clientId := params.GetClientId()
service := params.GetServiceId()
ctx := params.GetLogContext()
Expand Down Expand Up @@ -559,7 +558,7 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err
logger = logger.WithField("serviceName", svc.Name)

// 3: select terminator
strategy, terminator, pathNodes, circuitErr := network.selectPath(srcR, svc, instanceId, ctx)
strategy, terminator, pathNodes, strategyData, circuitErr := network.selectPath(params, svc, instanceId, ctx)
if circuitErr != nil {
network.CircuitFailedEvent(circuitId, params, startTime, nil, nil, circuitErr.Cause())
network.ServiceDialOtherError(serviceId)
Expand Down Expand Up @@ -653,6 +652,10 @@ func (network *Network) CreateCircuit(params CreateCircuitParams) (*Circuit, err
delete(peerData, uint32(ctrl_msg.TerminatorLocalAddressHeader))
delete(peerData, uint32(ctrl_msg.TerminatorRemoteAddressHeader))

for k, v := range strategyData {
peerData[k] = v
}

now := time.Now()
// 6: Create Circuit Object
circuit := &Circuit{
Expand Down Expand Up @@ -697,7 +700,7 @@ func parseInstanceIdAndService(service string) (string, string) {
return identityId, serviceId
}

func (network *Network) selectPath(srcR *Router, svc *Service, instanceId string, ctx logcontext.Context) (xt.Strategy, xt.CostedTerminator, []*Router, CircuitError) {
func (network *Network) selectPath(params CreateCircuitParams, svc *Service, instanceId string, ctx logcontext.Context) (xt.Strategy, xt.CostedTerminator, []*Router, xt.PeerData, CircuitError) {
paths := map[string]*PathAndCost{}
var weightedTerminators []xt.CostedTerminator
var errList []error
Expand Down Expand Up @@ -725,7 +728,7 @@ func (network *Network) selectPath(srcR *Router, svc *Service, instanceId string
continue
}

path, cost, err := network.shortestPath(srcR, dstR)
path, cost, err := network.shortestPath(params.GetSourceRouter(), dstR)
if err != nil {
log.Debugf("error while calculating path for service %v: %v", svc.Id, err)
errList = append(errList, err)
Expand All @@ -748,38 +751,38 @@ func (network *Network) selectPath(srcR *Router, svc *Service, instanceId string
}

if len(svc.Terminators) == 0 {
return nil, nil, nil, newCircuitErrorf(CircuitFailureNoTerminators, "service %v has no terminators", svc.Id)
return nil, nil, nil, nil, newCircuitErrorf(CircuitFailureNoTerminators, "service %v has no terminators", svc.Id)
}

if len(weightedTerminators) == 0 {
if pathError {
return nil, nil, nil, newCircuitErrWrap(CircuitFailureNoPath, errorz.MultipleErrors(errList))
return nil, nil, nil, nil, newCircuitErrWrap(CircuitFailureNoPath, errorz.MultipleErrors(errList))
}

if hasOfflineRouters {
return nil, nil, nil, newCircuitErrorf(CircuitFailureNoOnlineTerminators, "service %v has no online terminators for instanceId %v", svc.Id, instanceId)
return nil, nil, nil, nil, newCircuitErrorf(CircuitFailureNoOnlineTerminators, "service %v has no online terminators for instanceId %v", svc.Id, instanceId)
}

return nil, nil, nil, newCircuitErrorf(CircuitFailureNoTerminators, "service %v has no terminators for instanceId %v", svc.Id, instanceId)
return nil, nil, nil, nil, newCircuitErrorf(CircuitFailureNoTerminators, "service %v has no terminators for instanceId %v", svc.Id, instanceId)
}

strategy, err := network.strategyRegistry.GetStrategy(svc.TerminatorStrategy)
if err != nil {
return nil, nil, nil, newCircuitErrWrap(CircuitFailureInvalidStrategy, err)
return nil, nil, nil, nil, newCircuitErrWrap(CircuitFailureInvalidStrategy, err)
}

sort.Slice(weightedTerminators, func(i, j int) bool {
return weightedTerminators[i].GetRouteCost() < weightedTerminators[j].GetRouteCost()
})

terminator, err := strategy.Select(weightedTerminators)
terminator, peerData, err := strategy.Select(params, weightedTerminators)

if err != nil {
return nil, nil, nil, newCircuitErrorf(CircuitFailureStrategyError, "strategy %v errored selecting terminator for service %v: %v", svc.TerminatorStrategy, svc.Id, err)
return nil, nil, nil, nil, newCircuitErrorf(CircuitFailureStrategyError, "strategy %v errored selecting terminator for service %v: %v", svc.TerminatorStrategy, svc.Id, err)
}

if terminator == nil {
return nil, nil, nil, newCircuitErrorf(CircuitFailureStrategyError, "strategy %v did not select terminator for service %v", svc.TerminatorStrategy, svc.Id)
return nil, nil, nil, nil, newCircuitErrorf(CircuitFailureStrategyError, "strategy %v did not select terminator for service %v", svc.TerminatorStrategy, svc.Id)
}

path := paths[terminator.GetRouterId()].path
Expand All @@ -803,7 +806,7 @@ func (network *Network) selectPath(srcR *Router, svc *Service, instanceId string
log.Debugf("selected terminator %v for path %v from %v", terminator.GetId(), pathStr, buf.String())
}

return strategy, terminator, path, nil
return strategy, terminator, path, peerData, nil
}

func (network *Network) RemoveCircuit(circuitId string, now bool) error {
Expand Down
8 changes: 2 additions & 6 deletions controller/xt/costs.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,12 @@ func (p *precedence) IsRequired() bool {
return p.minCost == requireMinCost
}

func (p *precedence) getMinCost() uint32 {
func (p *precedence) GetBaseCost() uint32 {
return p.minCost
}

func (p *precedence) getMaxCost() uint32 {
return p.maxCost
}

func (p *precedence) GetBiasedCost(cost uint32) uint32 {
result := p.getMinCost() + cost
result := p.minCost + cost
if result > p.maxCost {
return p.maxCost
}
Expand Down
13 changes: 10 additions & 3 deletions controller/xt/xt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package xt

import (
"fmt"
"github.com/openziti/identity"
"github.com/openziti/ziti/common/logcontext"
"time"
)

Expand Down Expand Up @@ -60,16 +62,21 @@ type StrategyChangeEvent interface {
GetRemoved() []Terminator
}

type CreateCircuitParams interface {
GetServiceId() string
GetClientId() *identity.TokenId
GetLogContext() logcontext.Context
}

type Strategy interface {
Select(terminators []CostedTerminator) (CostedTerminator, error)
Select(param CreateCircuitParams, terminators []CostedTerminator) (CostedTerminator, PeerData, error)
HandleTerminatorChange(event StrategyChangeEvent) error
NotifyEvent(event TerminatorEvent)
}

type Precedence interface {
fmt.Stringer
getMinCost() uint32
getMaxCost() uint32
GetBaseCost() uint32
IsFailed() bool
IsDefault() bool
IsRequired() bool
Expand Down
6 changes: 3 additions & 3 deletions controller/xt_random/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ func (self *factory) NewStrategy() xt.Strategy {

type strategy struct{}

func (self *strategy) Select(terminators []xt.CostedTerminator) (xt.CostedTerminator, error) {
func (self *strategy) Select(_ xt.CreateCircuitParams, terminators []xt.CostedTerminator) (xt.CostedTerminator, xt.PeerData, error) {
terminators = xt.GetRelatedTerminators(terminators)
count := len(terminators)
if count == 1 {
return terminators[0], nil
return terminators[0], nil, nil
}
selected := rand.Intn(count)
return terminators[selected], nil
return terminators[selected], nil, nil
}

func (self *strategy) NotifyEvent(xt.TerminatorEvent) {}
Expand Down
6 changes: 3 additions & 3 deletions controller/xt_smartrouting/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (
)

/**
The smartrouting strategy relies purely on maninpulating costs and lets the smart routing algorithm pick the terminator.
The smart routing strategy relies purely on manipulating costs and lets the smart routing algorithm pick the terminator.
It increases costs by a small amount when a new circuit uses the terminator and drops it back down when the circuit
closes. It also increases the cost whenever a dial fails and decreases it whenever a dial succeeds. Dial successes
will only reduce costs by the amount that failures have previously increased it.
Expand Down Expand Up @@ -59,8 +59,8 @@ type strategy struct {
xt_common.CostVisitor
}

func (self *strategy) Select(terminators []xt.CostedTerminator) (xt.CostedTerminator, error) {
return terminators[0], nil
func (self *strategy) Select(_ xt.CreateCircuitParams, terminators []xt.CostedTerminator) (xt.CostedTerminator, xt.PeerData, error) {
return terminators[0], nil, nil
}

func (self *strategy) NotifyEvent(event xt.TerminatorEvent) {
Expand Down
99 changes: 99 additions & 0 deletions controller/xt_sticky/impl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
Copyright NetFoundry Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package xt_sticky

import (
"github.com/openziti/ziti/common/ctrl_msg"
"github.com/openziti/ziti/controller/xt"
"github.com/openziti/ziti/controller/xt_common"
"math"
"time"
)

const (
Name = "sticky"
)

/**
The sticky strategy uses the smart routing strategy to select an initial terminator for a client. It also
returns a token to the client which can be passed back in on subsequent dials. If the token is passed
back in, then strategy will try to use the same terminator. If it's not available a different terminator
will be selected and a different token will be returned.
*/

func NewFactory() xt.Factory {
return &factory{}
}

type factory struct{}

func (self *factory) GetStrategyName() string {
return Name
}

func (self *factory) NewStrategy() xt.Strategy {
strategy := strategy{
CostVisitor: xt_common.CostVisitor{
FailureCosts: xt.NewFailureCosts(math.MaxUint16/4, 20, 2),
CircuitCost: 2,
},
}
strategy.CostVisitor.FailureCosts.CreditOverTime(5, time.Minute)
return &strategy
}

type strategy struct {
xt_common.CostVisitor
}

func (self *strategy) Select(params xt.CreateCircuitParams, terminators []xt.CostedTerminator) (xt.CostedTerminator, xt.PeerData, error) {
id := params.GetClientId()
var result xt.CostedTerminator

terminators = xt.GetRelatedTerminators(terminators)

if id != nil {
if terminatorId, ok := id.Data[ctrl_msg.XtStickinessToken]; ok {
strId := string(terminatorId)
for _, terminator := range terminators {
if terminator.GetId() == strId {
result = terminator
break
}
}
}
}

if result == nil {
result = terminators[0]
}

return result, xt.PeerData{
ctrl_msg.XtStickinessToken: []byte(result.GetId()),
}, nil
}

func (self *strategy) NotifyEvent(event xt.TerminatorEvent) {
event.Accept(&self.CostVisitor)
}

func (self *strategy) HandleTerminatorChange(event xt.StrategyChangeEvent) error {
for _, t := range event.GetRemoved() {
self.FailureCosts.Clear(t.GetId())
}
return nil
}
8 changes: 4 additions & 4 deletions controller/xt_weighted/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ type strategy struct {
xt_common.CostVisitor
}

func (self *strategy) Select(terminators []xt.CostedTerminator) (xt.CostedTerminator, error) {
func (self *strategy) Select(_ xt.CreateCircuitParams, terminators []xt.CostedTerminator) (xt.CostedTerminator, xt.PeerData, error) {
terminators = xt.GetRelatedTerminators(terminators)
if len(terminators) == 1 {
return terminators[0], nil
return terminators[0], nil, nil
}

var costIdx []float32
Expand All @@ -81,11 +81,11 @@ func (self *strategy) Select(terminators []xt.CostedTerminator) (xt.CostedTermin
selected := rand.Float32()
for idx, cost := range costIdx {
if selected < cost {
return terminators[idx], nil
return terminators[idx], nil, nil
}
}

return terminators[0], nil
return terminators[0], nil, nil
}

func (self *strategy) NotifyEvent(event xt.TerminatorEvent) {
Expand Down
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ require (
github.com/openziti/jwks v1.0.3
github.com/openziti/metrics v1.2.51
github.com/openziti/runzmd v1.0.43
github.com/openziti/sdk-golang v0.23.32
github.com/openziti/sdk-golang v0.21.0-alpha-1.0.20240506224004-e6cf97a2cae1
github.com/openziti/secretstream v0.1.19
github.com/openziti/storage v0.2.37
github.com/openziti/transport/v2 v2.0.131
Expand All @@ -67,7 +67,7 @@ require (
github.com/rabbitmq/amqp091-go v1.8.1
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475
github.com/russross/blackfriday v1.6.0
github.com/shirou/gopsutil/v3 v3.24.3
github.com/shirou/gopsutil/v3 v3.24.4
github.com/sirupsen/logrus v1.9.3
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
github.com/spf13/cobra v1.8.0
Expand All @@ -84,7 +84,7 @@ require (
golang.org/x/sync v0.7.0
golang.org/x/sys v0.19.0
golang.org/x/text v0.14.0
google.golang.org/protobuf v1.33.0
google.golang.org/protobuf v1.34.0
gopkg.in/AlecAivazis/survey.v1 v1.8.8
gopkg.in/resty.v1 v1.12.0
gopkg.in/square/go-jose.v2 v2.6.0
Expand Down
Loading

0 comments on commit 3ef7eb2

Please sign in to comment.