Skip to content

Commit

Permalink
Merge pull request #667 from ekovacs/feature/allow-custom-tls-config-…
Browse files Browse the repository at this point in the history
…for-acceptor

Allow the clients of acceptor to specify their own tls.Config
  • Loading branch information
ackleymi authored Sep 4, 2024
2 parents 8a53aa9 + f4d6fe9 commit 5ec1219
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
43 changes: 43 additions & 0 deletions accepter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
package quickfix

import (
"crypto/tls"
"net"
"testing"

"github.com/quickfixgo/quickfix/config"

proxyproto "github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestAcceptor_Start(t *testing.T) {
Expand Down Expand Up @@ -83,3 +85,44 @@ func TestAcceptor_Start(t *testing.T) {
})
}
}

func TestAcceptor_SetTLSConfig(t *testing.T) {
sessionSettings := NewSessionSettings()
sessionSettings.Set(config.BeginString, BeginStringFIX42)
sessionSettings.Set(config.SenderCompID, "sender")
sessionSettings.Set(config.TargetCompID, "target")

genericSettings := NewSettings()

genericSettings.GlobalSettings().Set("SocketAcceptPort", "5001")
_, err := genericSettings.AddSession(sessionSettings)
require.NoError(t, err)

logger, err := NewScreenLogFactory().Create()
require.NoError(t, err)
acceptor := &Acceptor{settings: genericSettings, globalLog: logger}
defer acceptor.Stop()
// example of a customized tls.Config that loads the certificates dynamically by the `GetCertificate` function
// as opposed to the Certificates slice, that is static in nature, and is only populated once and needs application restart to reload the certs.
customizedTLSConfig := tls.Config{
Certificates: []tls.Certificate{},
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair("_test_data/localhost.crt", "_test_data/localhost.key")
if err != nil {
return nil, err
}
return &cert, nil
},
}

acceptor.SetTLSConfig(&customizedTLSConfig)
assert.NoError(t, acceptor.Start())
assert.Len(t, acceptor.listeners, 1)

conn, err := tls.Dial("tcp", "localhost:5001", &tls.Config{
InsecureSkipVerify: true,
})
require.NoError(t, err)
assert.NotNil(t, conn)
defer conn.Close()
}
24 changes: 19 additions & 5 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Acceptor struct {
sessionHostPort map[SessionID]int
listeners map[string]net.Listener
connectionValidator ConnectionValidator
tlsConfig *tls.Config
sessionFactory
}

Expand Down Expand Up @@ -81,9 +82,12 @@ func (a *Acceptor) Start() (err error) {
a.listeners[address] = nil
}

var tlsConfig *tls.Config
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
return
if a.tlsConfig == nil {
var tlsConfig *tls.Config
if tlsConfig, err = loadTLSConfig(a.settings.GlobalSettings()); err != nil {
return
}
a.tlsConfig = tlsConfig
}

var useTCPProxy bool
Expand All @@ -94,8 +98,8 @@ func (a *Acceptor) Start() (err error) {
}

for address := range a.listeners {
if tlsConfig != nil {
if a.listeners[address], err = tls.Listen("tcp", address, tlsConfig); err != nil {
if a.tlsConfig != nil {
if a.listeners[address], err = tls.Listen("tcp", address, a.tlsConfig); err != nil {
return
}
} else if a.listeners[address], err = net.Listen("tcp", address); err != nil {
Expand Down Expand Up @@ -421,3 +425,13 @@ LOOP:
func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) {
a.connectionValidator = validator
}

// SetTLSConfig allows the creator of the Acceptor to specify a fully customizable tls.Config of their choice,
// which will be used in the Start() method.
//
// Note: when the caller explicitly provides a tls.Config with this function,
// it takes precendent over TLS settings specified in the acceptor's settings.GlobalSettings(),
// meaning that the `settings.GlobalSettings()` object is not inspected or used for the creation of the tls.Config.
func (a *Acceptor) SetTLSConfig(tlsConfig *tls.Config) {
a.tlsConfig = tlsConfig
}

0 comments on commit 5ec1219

Please sign in to comment.