Skip to content

Commit

Permalink
feat(client): support auto enable tls by schema (#585)
Browse files Browse the repository at this point in the history
### Motivation

Same as streamnative/oxia-java#192. Support
auto-enable TLS on the client.
  • Loading branch information
mattisonchao authored Nov 25, 2024
1 parent 79fea6a commit b534ac8
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 6 deletions.
30 changes: 24 additions & 6 deletions common/client_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"crypto/tls"
"io"
"log/slog"
"strings"
"sync"
"time"

Expand All @@ -37,6 +38,7 @@ import (
)

const DefaultRpcTimeout = 30 * time.Second
const AddressSchemaTLS = "tls://"

type ClientPool interface {
io.Closer
Expand Down Expand Up @@ -170,11 +172,7 @@ func (cp *clientPool) newConnection(target string) (*grpc.ClientConn, error) {
slog.String("server_address", target),
)

// tls configure
tcs := insecure.NewCredentials()
if cp.tls != nil {
tcs = credentials.NewTLS(cp.tls)
}
tcs := cp.getTransportCredential(target)

options := []grpc.DialOption{
grpc.WithTransportCredentials(tcs),
Expand All @@ -184,14 +182,34 @@ func (cp *clientPool) newConnection(target string) (*grpc.ClientConn, error) {
if cp.authentication != nil {
options = append(options, grpc.WithPerRPCCredentials(cp.authentication))
}
cnx, err := grpc.NewClient(target, options...)
cnx, err := grpc.NewClient(cp.getActualAddress(target), options...)
if err != nil {
return nil, errors.Wrapf(err, "error connecting to %s", target)
}

return cnx, nil
}

func (*clientPool) getActualAddress(target string) string {
if strings.HasPrefix(target, AddressSchemaTLS) {
after, _ := strings.CutPrefix(target, AddressSchemaTLS)
return after
}
return target
}

//nolint:gosec
func (cp *clientPool) getTransportCredential(target string) credentials.TransportCredentials {
tcs := insecure.NewCredentials()
if strings.HasPrefix(target, AddressSchemaTLS) {
tcs = credentials.NewTLS(&tls.Config{})
}
if cp.tls != nil {
tcs = credentials.NewTLS(cp.tls)
}
return tcs
}

func GetPeer(ctx context.Context) string {
p, ok := peer.FromContext(ctx)
if !ok {
Expand Down
42 changes: 42 additions & 0 deletions common/client_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright 2024 StreamNative, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package common

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestClientPool_GetActualAddress(t *testing.T) {
pool := NewClientPool(nil, nil)
poolInstance := pool.(*clientPool)

address := poolInstance.getActualAddress("tls://xxxxaa:6648")
assert.Equal(t, "xxxxaa:6648", address)

actualAddress := poolInstance.getActualAddress("xxxxaaa:6649")
assert.Equal(t, "xxxxaaa:6649", actualAddress)
}

func TestClientPool_GetTransportCredential(t *testing.T) {
pool := NewClientPool(nil, nil)
poolInstance := pool.(*clientPool)

credential := poolInstance.getTransportCredential("tls://xxxxaa:6648")
assert.Equal(t, "tls", credential.Info().SecurityProtocol)

credential = poolInstance.getTransportCredential("xxxxaaa:6649")
assert.Equal(t, "insecure", credential.Info().SecurityProtocol)
}

0 comments on commit b534ac8

Please sign in to comment.