Skip to content

Commit

Permalink
Remote codec-endpoint cherry-pick (#420) (#421)
Browse files Browse the repository at this point in the history
This is #420 cherry-picked onto
`main`. We'll release either from `main` or `v0.10.8.rc` but either way
we want the chanes in `main`.

This replaces #384
  • Loading branch information
dandavison authored Jan 29, 2024
1 parent 8aefcd1 commit 41b9a3b
Show file tree
Hide file tree
Showing 12 changed files with 197 additions and 201 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ linters-settings:
disabled: true
- name: file-header
disabled: true
- name: flag-parameter
disabled: true
- name: function-length
disabled: true
- name: imports-blacklist
Expand Down
8 changes: 3 additions & 5 deletions activity/activity_commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import (

"github.com/temporalio/cli/client"
"github.com/temporalio/cli/common"
"github.com/temporalio/cli/dataconverter"
"github.com/temporalio/tctl-kit/pkg/color"
"github.com/urfave/cli/v2"
failurepb "go.temporal.io/api/failure/v1"
"go.temporal.io/api/workflowservice/v1"
"go.temporal.io/sdk/converter"
)

// CompleteActivity completes an Activity
Expand All @@ -27,9 +27,7 @@ func CompleteActivity(c *cli.Context) error {
ctx, cancel := common.NewContext(c)
defer cancel()

// TODO: This should use common.CustomDataConverter once the plugin interface
// supports the full DataConverter API.
resultPayloads, _ := dataconverter.DefaultDataConverter().ToPayloads(result)
resultPayloads, _ := converter.GetDefaultDataConverter().ToPayloads(result)

frontendClient := client.Factory(c.App).FrontendClient(c)
_, err = frontendClient.RespondActivityTaskCompletedById(ctx, &workflowservice.RespondActivityTaskCompletedByIdRequest{
Expand Down Expand Up @@ -68,7 +66,7 @@ func FailActivity(c *cli.Context) error {
ctx, cancel := common.NewContext(c)
defer cancel()

detailsPayloads, _ := dataconverter.DefaultDataConverter().ToPayloads(detail)
detailsPayloads, _ := converter.GetDefaultDataConverter().ToPayloads(detail)

frontendClient := client.Factory(c.App).FrontendClient(c)
_, err = frontendClient.RespondActivityTaskFailedById(ctx, &workflowservice.RespondActivityTaskFailedByIdRequest{
Expand Down
113 changes: 113 additions & 0 deletions client/codec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// The MIT License
//
// Copyright (c) 2021 Temporal Technologies Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

// Vendored code from sdk-go.

package client

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"

"github.com/gogo/protobuf/jsonpb"
commonpb "go.temporal.io/api/common/v1"
"go.temporal.io/sdk/converter"
)

// RemotePayloadCodecOptions are options for RemotePayloadCodec.
// Client is optional.
type RemotePayloadCodecOptions struct {
Endpoint string
ModifyRequest func(*http.Request) error
Client http.Client
}

type remotePayloadCodec struct {
options RemotePayloadCodecOptions
}

const remotePayloadCodecEncodePath = "/encode"
const remotePayloadCodecDecodePath = "/decode"

// NewRemotePayloadCodec creates a PayloadCodec using the remote endpoint configured by RemotePayloadCodecOptions.
func NewRemotePayloadCodec(options RemotePayloadCodecOptions) converter.PayloadCodec {
return &remotePayloadCodec{options}
}

// Encode uses the remote payload codec endpoint to encode payloads.
func (pc *remotePayloadCodec) Encode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecEncodePath, payloads)
}

// Decode uses the remote payload codec endpoint to decode payloads.
func (pc *remotePayloadCodec) Decode(payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
return pc.encodeOrDecode(pc.options.Endpoint+remotePayloadCodecDecodePath, payloads)
}

func (pc *remotePayloadCodec) encodeOrDecode(endpoint string, payloads []*commonpb.Payload) ([]*commonpb.Payload, error) {
requestPayloads, err := json.Marshal(commonpb.Payloads{Payloads: payloads})
if err != nil {
return payloads, fmt.Errorf("unable to marshal payloads: %w", err)
}

req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(requestPayloads))
if err != nil {
return payloads, fmt.Errorf("unable to build request: %w", err)
}

req.Header.Set("Content-Type", "application/json")

if pc.options.ModifyRequest != nil {
err = pc.options.ModifyRequest(req)
if err != nil {
return payloads, err
}
}

response, err := pc.options.Client.Do(req)
if err != nil {
return payloads, err
}
defer func() { _ = response.Body.Close() }()

if response.StatusCode == 200 {
bs, err := io.ReadAll(response.Body)
if err != nil {
return payloads, fmt.Errorf("failed to read response body: %w", err)
}
var resultPayloads commonpb.Payloads
err = jsonpb.UnmarshalString(string(bs), &resultPayloads)
if err != nil {
return payloads, fmt.Errorf("unable to unmarshal payloads: %w", err)
}
if len(payloads) != len(resultPayloads.Payloads) {
return payloads, fmt.Errorf("received %d payloads from remote codec, expected %d", len(resultPayloads.Payloads), len(payloads))
}
return resultPayloads.Payloads, nil
}

message, _ := io.ReadAll(response.Body)
return payloads, fmt.Errorf("%s: %s", http.StatusText(response.StatusCode), message)
}
55 changes: 44 additions & 11 deletions client/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ import (

"github.com/gogo/status"
"github.com/temporalio/cli/common"
"github.com/temporalio/cli/dataconverter"
"github.com/temporalio/cli/headersprovider"
"github.com/urfave/cli/v2"
"go.temporal.io/api/operatorservice/v1"
"go.temporal.io/api/serviceerror"
"go.temporal.io/api/workflowservice/v1"
sdkclient "go.temporal.io/sdk/client"
"go.temporal.io/sdk/converter"
"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/log"
"go.temporal.io/server/common/log/tag"
Expand Down Expand Up @@ -125,13 +125,23 @@ func (b *clientFactory) SDKClient(c *cli.Context, namespace string) sdkclient.Cl
b.logger.Fatal("Failed to configure TLS for SDK client", tag.Error(err))
}

dialOptions := []grpc.DialOption{}
if codecEndpoint := c.String(common.FlagCodecEndpoint); codecEndpoint != "" {
interceptor, err := newPayloadCodecGRPCClientInterceptor(c, codecEndpoint)
if err != nil {
b.logger.Fatal("Failed to configure payload codec interceptor", tag.Error(err))
}
dialOptions = append(dialOptions, grpc.WithChainUnaryInterceptor(interceptor))
}

sdkClient, err := sdkclient.Dial(sdkclient.Options{
HostPort: hostPort,
Namespace: namespace,
Logger: log.NewSdkLogger(b.logger),
Identity: common.GetCliIdentity(),
ConnectionOptions: sdkclient.ConnectionOptions{
TLS: tlsConfig,
DialOptions: dialOptions,
TLS: tlsConfig,
},
HeadersProvider: headersprovider.GetCurrent(),
})
Expand All @@ -142,6 +152,31 @@ func (b *clientFactory) SDKClient(c *cli.Context, namespace string) sdkclient.Cl
return sdkClient
}

func newPayloadCodecGRPCClientInterceptor(c *cli.Context, codecEndpoint string) (grpc.UnaryClientInterceptor, error) {
namespace := c.String(common.FlagNamespace)
codecAuth := c.String(common.FlagCodecAuth)
codecEndpoint = strings.ReplaceAll(codecEndpoint, "{namespace}", namespace)

payloadCodec := NewRemotePayloadCodec(
RemotePayloadCodecOptions{
Endpoint: codecEndpoint,
ModifyRequest: func(req *http.Request) error {
req.Header.Set("X-Namespace", namespace)
if codecAuth != "" {
req.Header.Set("Authorization", codecAuth)
}

return nil
},
},
)
return converter.NewPayloadCodecGRPCClientInterceptor(
converter.PayloadCodecGRPCClientInterceptorOptions{
Codecs: []converter.PayloadCodec{payloadCodec},
},
)
}

// HealthClient builds a health client.
func (b *clientFactory) HealthClient(c *cli.Context) healthpb.HealthClient {
connection, _ := b.createGRPCConnection(c)
Expand All @@ -150,15 +185,6 @@ func (b *clientFactory) HealthClient(c *cli.Context) healthpb.HealthClient {
}

func configureSDK(ctx *cli.Context) error {
endpoint := ctx.String(common.FlagCodecEndpoint)
if endpoint != "" {
dataconverter.SetRemoteEndpoint(
endpoint,
ctx.String(common.FlagNamespace),
ctx.String(common.FlagCodecAuth),
)
}

md, err := common.SplitKeyValuePairs(ctx.StringSlice(common.FlagMetadata))
if err != nil {
return err
Expand Down Expand Up @@ -211,6 +237,13 @@ func (b *clientFactory) createGRPCConnection(c *cli.Context) (*grpc.ClientConn,
errorInterceptor(),
headersProviderInterceptor(headersprovider.GetCurrent()),
}
if codecEndpoint := c.String(common.FlagCodecEndpoint); codecEndpoint != "" {
interceptor, err := newPayloadCodecGRPCClientInterceptor(c, codecEndpoint)
if err != nil {
b.logger.Fatal("Failed to configure payload codec interceptor", tag.Error(err))
}
interceptors = append(interceptors, interceptor)
}

dialOpts := []grpc.DialOption{
grpcSecurityOptions,
Expand Down
21 changes: 12 additions & 9 deletions common/stringify/stringify.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ const (
maxWordLength = 120 // if text length is larger than maxWordLength, it will be inserted spaces
)

func AnyToString(val interface{}, printFully bool, maxFieldLength int, dc converter.DataConverter) string {
//revive:disable:cognitive-complexity
//revive:disable:cyclomatic
func AnyToString(val interface{}, printFully bool, maxFieldLength int) string {
dc := converter.GetDefaultDataConverter()
v := reflect.ValueOf(val)
if val == nil || (v.Kind() == reflect.Ptr && v.IsNil()) {
return ""
Expand Down Expand Up @@ -65,9 +68,9 @@ func AnyToString(val interface{}, printFully bool, maxFieldLength int, dc conver
return ""
case reflect.Slice:
// All but []byte which is already handled.
return sliceToString(v, printFully, maxFieldLength, dc)
return sliceToString(v, printFully, maxFieldLength)
case reflect.Ptr:
return AnyToString(v.Elem().Interface(), printFully, maxFieldLength, dc)
return AnyToString(v.Elem().Interface(), printFully, maxFieldLength)
case reflect.Map:
type keyValuePair struct {
key string
Expand All @@ -82,11 +85,11 @@ func AnyToString(val interface{}, printFully bool, maxFieldLength int, dc conver
if !mapKey.CanInterface() || !mapVal.CanInterface() {
continue
}
mapKeyStr := AnyToString(mapKey.Interface(), true, 0, dc)
mapKeyStr := AnyToString(mapKey.Interface(), true, 0)
if mapKeyStr == "" {
continue
}
mapValStr := AnyToString(mapVal.Interface(), true, 0, dc)
mapValStr := AnyToString(mapVal.Interface(), true, 0)
if mapValStr == "" {
continue
}
Expand Down Expand Up @@ -127,7 +130,7 @@ func AnyToString(val interface{}, printFully bool, maxFieldLength int, dc conver
}

fieldName := t.Field(i).Name
fieldStr := AnyToString(f.Interface(), printFully, maxFieldLength, dc)
fieldStr := AnyToString(f.Interface(), printFully, maxFieldLength)
if fieldStr == "" {
continue
}
Expand Down Expand Up @@ -168,17 +171,17 @@ func AnyToString(val interface{}, printFully bool, maxFieldLength int, dc conver
}
}

func sliceToString(slice reflect.Value, printFully bool, maxFieldLength int, dc converter.DataConverter) string {
func sliceToString(slice reflect.Value, printFully bool, maxFieldLength int) string {
var b strings.Builder
b.WriteRune('[')
for i := 0; i < slice.Len(); i++ {
if i == 0 || printFully {
b.WriteString(AnyToString(slice.Index(i).Interface(), printFully, maxFieldLength, dc))
_, _ = b.WriteString(AnyToString(slice.Index(i).Interface(), printFully, maxFieldLength))
if i < slice.Len()-1 {
b.WriteRune(',')
}
if !printFully && slice.Len() > 1 {
b.WriteString(fmt.Sprintf("...%d more]", slice.Len()-1))
_, _ = b.WriteString(fmt.Sprintf("...%d more]", slice.Len()-1))
return b.String()
}
}
Expand Down
Loading

0 comments on commit 41b9a3b

Please sign in to comment.