Skip to content

Commit

Permalink
Merge pull request #408 from wneessen/feature/404_unix-domain-socket-…
Browse files Browse the repository at this point in the history
…support

Unix domain socket support
  • Loading branch information
wneessen authored Jan 8, 2025
2 parents 42ce0bf + fa129a2 commit eb48854
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 10 deletions.
22 changes: 20 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ type (
// user represents a username used for the SMTP authentication.
user string

// useUnixSocket indicates that a connection is established via a Unix Domain Socket instead of TCP
useUnixSocket bool

// useSSL indicates whether to use SSL/TLS encryption for network communication.
//
// https://datatracker.ietf.org/doc/html/rfc8314
Expand Down Expand Up @@ -288,6 +291,12 @@ func NewClient(host string, opts ...Option) (*Client, error) {
}
}

// We allow connecting to a UNIX Domain Socket
if strings.HasPrefix(c.host, "unix://") {
c.useUnixSocket = true
c.host = strings.TrimPrefix(c.host, "unix://")
}

// Some settings in a Client cannot be empty/unset
if c.host == "" {
return c, ErrNoHostname
Expand Down Expand Up @@ -722,6 +731,9 @@ func (c *Client) TLSPolicy() string {
// Returns:
// - A string representing the server address in the format "host:port".
func (c *Client) ServerAddr() string {
if c.useUnixSocket {
return c.host
}
return fmt.Sprintf("%s:%d", c.host, c.port)
}

Expand Down Expand Up @@ -1004,8 +1016,14 @@ func (c *Client) DialToSMTPClientWithContext(ctxDial context.Context) (*smtp.Cli
dialContextFunc = tlsDialer.DialContext
}
}
connection, err := dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {

network := "tcp"
if c.useUnixSocket {
network = "unix"
}

connection, err := dialContextFunc(ctx, network, c.ServerAddr())
if err != nil && !c.useUnixSocket && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
connection, err = dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
Expand Down
94 changes: 86 additions & 8 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,19 @@ func TestNewClient(t *testing.T) {
})
}
})
t.Run("NewClient on Unix Domain Socket", func(t *testing.T) {
client, err := NewClient("unix:///tmp/mail.sock")
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
if !client.useUnixSocket {
t.Error("Expected useUnixSocket flag to be set to true")
}
if !strings.EqualFold(client.host, "/tmp/mail.sock") {
t.Errorf("expected host to be set to unix socket path, expected: %s, got: %s", "/tmp/mail.sock",
client.host)
}
})
}

func TestClient_TLSPolicy(t *testing.T) {
Expand Down Expand Up @@ -2816,6 +2829,64 @@ func TestClient_DialToSMTPClientWithContext(t *testing.T) {
t.Fatal("expected connection to fake to fail")
}
})
t.Run("dial to Unix domain socket", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
PortAdder.Add(1)
serverPort := int(TestServerPortBase + PortAdder.Load())
featureSet := "250-8BITMIME\r\n250-DSN\r\n250 SMTPUTF8"
props := &serverProps{
FeatureSet: featureSet,
ListenPort: serverPort,
UnixSocket: true,
}
go func() {
if err := simpleSMTPServer(ctx, t, props); err != nil {
t.Errorf("failed to start test server: %s", err)
return
}
}()
time.Sleep(time.Millisecond * 30)

ctxDial, cancelDial := context.WithTimeout(ctx, time.Millisecond*500)
t.Cleanup(cancelDial)
t.Cleanup(func() {
if err := os.RemoveAll(props.UnixSocketPath); err != nil {
t.Errorf("failed to remove unix socket: %s", err)
}
})

client, err := NewClient("unix://"+props.UnixSocketPath+"/server.sock", WithTLSPolicy(NoTLS))
if err != nil {
t.Fatalf("failed to create new client: %s", err)
}
smtpClient, err := client.DialToSMTPClientWithContext(ctxDial)
if err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to connect to the test server due to timeout")
}
t.Fatalf("failed to connect to test server: %s", err)
}
t.Cleanup(func() {
if err := client.CloseWithSMTPClient(smtpClient); err != nil {
var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
t.Skip("failed to close the test server connection due to timeout")
}
t.Errorf("failed to close client: %s", err)
}
})
if smtpClient == nil {
t.Fatal("expected SMTP client, got nil")
}
if !smtpClient.HasConnection() {
t.Fatal("expected connection on smtp client")
}
if ok, _ := smtpClient.Extension("DSN"); !ok {
t.Error("expected DSN extension but it was not found")
}
})
}

func TestClient_sendSingleMsg(t *testing.T) {
Expand Down Expand Up @@ -3837,6 +3908,8 @@ type serverProps struct {
SSLListener bool
IsTLS bool
SupportDSN bool
UnixSocket bool
UnixSocketPath string
}

// simpleSMTPServer starts a simple TCP server that resonds to SMTP commands.
Expand All @@ -3850,18 +3923,23 @@ func simpleSMTPServer(ctx context.Context, t *testing.T, props *serverProps) err

var listener net.Listener
var err error
if props.SSLListener {
keypair, err := tls.X509KeyPair(localhostCert, localhostKey)
if err != nil {
return fmt.Errorf("failed to read TLS keypair: %w", err)
switch {
case props.UnixSocket:
path, perr := os.MkdirTemp("", "go-mail-server-*")
if perr != nil {
return fmt.Errorf("failed to create temp directory: %w", perr)
}
listener, err = net.Listen("unix", path+"/server.sock")
props.UnixSocketPath = path
case props.SSLListener:
keypair, kerr := tls.X509KeyPair(localhostCert, localhostKey)
if kerr != nil {
return fmt.Errorf("failed to read TLS keypair: %w", kerr)
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{keypair}}
listener, err = tls.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort),
tlsConfig)
if err != nil {
t.Fatalf("failed to create TLS listener: %s", err)
}
} else {
default:
listener, err = net.Listen(TestServerProto, fmt.Sprintf("%s:%d", TestServerAddr, props.ListenPort))
}
if err != nil {
Expand Down

0 comments on commit eb48854

Please sign in to comment.