Skip to content

Commit

Permalink
Add support for additional AMQP URI query parameters
Browse files Browse the repository at this point in the history
https://www.rabbitmq.com/docs/uri-query-parameters specifies several parameters that are used in this library, but not yet supported in URIs.

This commit adds support for the following parameters:
auth_mechanism
heartbeat
connection_timeout
channel_max
  • Loading branch information
vilius-g committed Feb 29, 2024
1 parent e61228a commit 2e14735
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 11 deletions.
30 changes: 29 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,39 @@ func DialConfig(url string, config Config) (*Connection, error) {
config.Vhost = uri.Vhost
}

if config.Heartbeat == 0 {
config.Heartbeat = time.Duration(uri.Heartbeat) * time.Second
}

if config.ChannelMax == 0 {
config.ChannelMax = uri.ChannelMax
}

connectionTimeout := defaultConnectionTimeout
if uri.ConnectionTimeout != 0 {
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
}

if config.SASL == nil && uri.AuthMechanism != nil {
for _, identifier := range uri.AuthMechanism {
switch strings.ToUpper(identifier) {
case "PLAIN":
config.SASL = append(config.SASL, uri.PlainAuth())
case "AMQPLAIN":
config.SASL = append(config.SASL, uri.AMQPlainAuth())
case "EXTERNAL":
config.SASL = append(config.SASL, &ExternalAuth{})
default:
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
}
}
}

addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))

dialer := config.Dial
if dialer == nil {
dialer = DefaultDial(defaultConnectionTimeout)
dialer = DefaultDial(connectionTimeout)
}

conn, err = dialer("tcp", addr)
Expand Down
50 changes: 40 additions & 10 deletions uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amqp091

import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
Expand All @@ -32,16 +33,20 @@ var defaultURI = URI{

// URI represents a parsed AMQP URI string.
type URI struct {
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
AuthMechanism []string
Heartbeat int
ConnectionTimeout int
ChannelMax int
}

// ParseURI attempts to parse the given AMQP URI according to the spec.
Expand Down Expand Up @@ -134,6 +139,31 @@ func ParseURI(uri string) (URI, error) {
builder.KeyFile = params.Get("keyfile")
builder.CACertFile = params.Get("cacertfile")
builder.ServerName = params.Get("server_name_indication")
builder.AuthMechanism = params["auth_mechanism"]

if params.Has("heartbeat") {
value, err := strconv.Atoi(params.Get("heartbeat"))
if err != nil {
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
}
builder.Heartbeat = value
}

if params.Has("connection_timeout") {
value, err := strconv.Atoi(params.Get("connection_timeout"))
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ConnectionTimeout = value
}

if params.Has("channel_max") {
value, err := strconv.Atoi(params.Get("channel_max"))
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ChannelMax = value
}

return builder, nil
}
Expand Down
21 changes: 21 additions & 0 deletions uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package amqp091

import (
"reflect"
"testing"
)

Expand Down Expand Up @@ -388,3 +389,23 @@ func TestURITLSConfig(t *testing.T) {
t.Fatal("Server name not set")
}
}

func TestURIParameters(t *testing.T) {
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
uri, err := ParseURI(url)
if err != nil {
t.Fatal("Could not parse")
}
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
t.Fatal("AuthMechanism not set")
}
if uri.Heartbeat != 2 {
t.Fatal("Heartbeat not set")
}
if uri.ConnectionTimeout != 5000 {
t.Fatal("ConnectionTimeout not set")
}
if uri.ChannelMax != 8 {
t.Fatal("ChannelMax name not set")
}
}

0 comments on commit 2e14735

Please sign in to comment.