Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sasl.Mechanism safe for concurrent use #323

Merged
merged 1 commit into from
Jul 24, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,16 @@ func (d *Dialer) connect(ctx context.Context, network, address string, connCfg C
// In case of error, this function *does not* close the connection. That is the
// responsibility of the caller.
func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
mech, state, err := d.SASLMechanism.Start(ctx)
if err != nil {
if err := conn.saslHandshake(d.SASLMechanism.Name()); err != nil {
return err
}
err = conn.saslHandshake(mech)

sess, state, err := d.SASLMechanism.Start(ctx)
if err != nil {
return err
}

var completed bool
for !completed {
for completed := false; !completed; {
challenge, err := conn.saslAuthenticate(state)
switch err {
case nil:
Expand All @@ -306,7 +305,7 @@ func (d *Dialer) authenticateSASL(ctx context.Context, conn *Conn) error {
return err
}

completed, state, err = d.SASLMechanism.Next(ctx, challenge)
completed, state, err = sess.Next(ctx, challenge)
if err != nil {
return err
}
Expand Down
11 changes: 9 additions & 2 deletions sasl/plain/plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package plain
import (
"context"
"fmt"

"github.com/segmentio/kafka-go/sasl"
)

// Mechanism implements the PLAIN mechanism and passes the credentials in clear
Expand All @@ -12,8 +14,13 @@ type Mechanism struct {
Password string
}

func (m Mechanism) Start(ctx context.Context) (string, []byte, error) {
return "PLAIN", []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil
func (Mechanism) Name() string {
return "PLAIN"
}

func (m Mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) {
// Mechanism is stateless, so it can also implement sasl.Session
return m, []byte(fmt.Sprintf("\x00%s\x00%s", m.Username, m.Password)), nil
}

func (m Mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {
Expand Down
47 changes: 30 additions & 17 deletions sasl/sasl.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,43 @@ package sasl

import "context"

// Mechanism implements the SASL state machine. It is initialized by calling
// Start at which point the initial bytes should be sent to the server. The
// caller then loops by passing the server's response into Next and then sending
// Next's returned bytes to the server. Eventually either Next will indicate
// that the authentication has been successfully completed or an error will
// cause the state machine to exit prematurely.
// Mechanism implements the SASL state machine for a particular mode of
// authentication. It is used by the kafka.Dialer to perform the SASL
// handshake.
//
// A Mechanism must be re-usable, but it does not need to be safe for concurrent
// access by multiple go routines.
// A Mechanism must be re-usable and safe for concurrent access by multiple
// goroutines.
type Mechanism interface {
// Start begins SASL authentication. It returns the authentication mechanism
// name and "initial response" data (if required by the selected mechanism).
// A non-nil error causes the client to abort the authentication attempt.
// Name returns the identifier for this SASL mechanism. This string will be
// passed to the SASL handshake request and much match one of the mechanisms
// supported by Kafka.
Name() string

// Start begins SASL authentication. It returns an authentication state
// machine and "initial response" data (if required by the selected
// mechanism). A non-nil error causes the client to abort the authentication
// attempt.
//
// A nil ir value is different from a zero-length value. The nil value
// indicates that the selected mechanism does not use an initial response,
// while a zero-length value indicates an empty initial response, which must
// be sent to the server.
//
// In order to ensure that the Mechanism is reusable, calling Start must
// reset any internal state.
Start(ctx context.Context) (mech string, ir []byte, err error)
Start(ctx context.Context) (sess StateMachine, ir []byte, err error)
}

// Next continues challenge-response authentication. A non-nil error causes
// the client to abort the authentication attempt.
// StateMachine implements the SASL challenge/response flow for a single SASL
// handshake. A StateMachine will be created by the Mechanism per connection,
// so it does not need to be safe for concurrent access by multiple goroutines.
//
// Once the StateMachine is created by the Mechanism, the caller loops by
// passing the server's response into Next and then sending Next's returned
// bytes to the server. Eventually either Next will indicate that the
// authentication has been successfully completed via the done return value, or
// it will indicate that the authentication failed by returning a non-nil error.
type StateMachine interface {
// Next continues challenge-response authentication. A non-nil error
// indicates that the client should abort the authentication attempt. If
// the client has been successfully authenticated, then the done return
// value will be true.
Next(ctx context.Context, challenge []byte) (done bool, response []byte, err error)
}
8 changes: 4 additions & 4 deletions sasl/sasl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,18 +65,18 @@ func TestSASL(t *testing.T) {
}

for _, tt := range tests {
name, _, _ := tt.valid().Start(context.Background())
mech := tt.valid()
if !ktesting.KafkaIsAtLeast(tt.minKafka) {
t.Skip("requires min kafka version " + tt.minKafka)
}

t.Run(name+" success", func(t *testing.T) {
t.Run(mech.Name()+" success", func(t *testing.T) {
testConnect(t, tt.valid(), true)
})
t.Run(name+" failure", func(t *testing.T) {
t.Run(mech.Name()+" failure", func(t *testing.T) {
testConnect(t, tt.invalid(), false)
})
t.Run(name+" is reusable", func(t *testing.T) {
t.Run(mech.Name()+" is reusable", func(t *testing.T) {
mech := tt.valid()
testConnect(t, mech, true)
testConnect(t, mech, true)
Expand Down
26 changes: 18 additions & 8 deletions sasl/scram/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ var (
type mechanism struct {
algo Algorithm
client *scram.Client
convo *scram.ClientConversation
}

type session struct {
convo *scram.ClientConversation
}

// Mechanism returns a new sasl.Mechanism that will use SCRAM with the provided
Expand All @@ -69,13 +72,20 @@ func Mechanism(algo Algorithm, username, password string) (sasl.Mechanism, error
}, nil
}

func (m *mechanism) Start(ctx context.Context) (string, []byte, error) {
m.convo = m.client.NewConversation()
str, err := m.convo.Step("")
return m.algo.Name(), []byte(str), err
func (m *mechanism) Name() string {
return m.algo.Name()
}

func (m *mechanism) Start(ctx context.Context) (sasl.StateMachine, []byte, error) {
convo := m.client.NewConversation()
str, err := convo.Step("")
if err != nil {
return nil, nil, err
}
return &session{convo: convo}, []byte(str), nil
}

func (m *mechanism) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {
str, err := m.convo.Step(string(challenge))
return m.convo.Done(), []byte(str), err
func (s *session) Next(ctx context.Context, challenge []byte) (bool, []byte, error) {
str, err := s.convo.Step(string(challenge))
return s.convo.Done(), []byte(str), err
}