diff --git a/modules/rabbitmq/rabbitmq.go b/modules/rabbitmq/rabbitmq.go index 9fb28212e1..a6dec1a779 100644 --- a/modules/rabbitmq/rabbitmq.go +++ b/modules/rabbitmq/rabbitmq.go @@ -48,7 +48,7 @@ func (c *RabbitMQContainer) AmqpURL(ctx context.Context) (string, error) { // AmqpURL returns the URL for AMQPS clients. func (c *RabbitMQContainer) AmqpsURL(ctx context.Context) (string, error) { - endpoint, err := c.PortEndpoint(ctx, nat.Port(DefaultAMQPPort), "") + endpoint, err := c.PortEndpoint(ctx, nat.Port(DefaultAMQPSPort), "") if err != nil { return "", err } diff --git a/modules/rabbitmq/rabbitmq_test.go b/modules/rabbitmq/rabbitmq_test.go index 0c85c66607..7079379421 100644 --- a/modules/rabbitmq/rabbitmq_test.go +++ b/modules/rabbitmq/rabbitmq_test.go @@ -2,8 +2,12 @@ package rabbitmq_test import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "io" + "io/ioutil" + "path/filepath" "strings" "testing" @@ -42,6 +46,59 @@ func TestRunContainer_connectUsingAmqp(t *testing.T) { } } +func TestRunContainer_connectUsingAmqps(t *testing.T) { + ctx := context.Background() + + sslSettings := rabbitmq.SSLSettings{ + CACertFile: filepath.Join("testdata", "certs", "server_ca.pem"), + CertFile: filepath.Join("testdata", "certs", "server_cert.pem"), + KeyFile: filepath.Join("testdata", "certs", "server_key.pem"), + VerificationMode: rabbitmq.SSLVerificationModePeer, + FailIfNoCert: false, + VerificationDepth: 1, + } + + rabbitmqContainer, err := rabbitmq.RunContainer(ctx, rabbitmq.WithSSL(sslSettings)) + if err != nil { + t.Fatal(err) + } + + defer func() { + if err := rabbitmqContainer.Terminate(ctx); err != nil { + t.Fatal(err) + } + }() + + amqpsURL, err := rabbitmqContainer.AmqpsURL(ctx) + if err != nil { + t.Fatal(err) + } + + if !strings.HasPrefix(amqpsURL, "amqps") { + t.Fatal(fmt.Errorf("AMQPS Url should begin with `amqps`")) + } + + certs := x509.NewCertPool() + + pemData, err := ioutil.ReadFile(sslSettings.CACertFile) + if err != nil { + t.Fatal(err) + } + certs.AppendCertsFromPEM(pemData) + + amqpsConnection, err := amqp.DialTLS(amqpsURL, &tls.Config{InsecureSkipVerify: false, RootCAs: certs}) + if err != nil { + t.Fatal(err) + } + + if amqpsConnection.IsClosed() { + t.Fatal(fmt.Errorf("AMQPS Connection unexpectdely closed")) + } + if err = amqpsConnection.Close(); err != nil { + t.Fatal(err) + } +} + func TestRunContainer_withAllSettings(t *testing.T) { ctx := context.Background()