-
Notifications
You must be signed in to change notification settings - Fork 670
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Override ml endpoint with PredictEndpoint if set
- Loading branch information
1 parent
acfe56d
commit 827e3de
Showing
6 changed files
with
248 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
97 changes: 97 additions & 0 deletions
97
...va/software/amazon/smithy/aws/go/codegen/customization/MachineLearningCustomizations.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
package software.amazon.smithy.aws.go.codegen.customization; | ||
|
||
import java.util.List; | ||
import software.amazon.smithy.aws.traits.ServiceTrait; | ||
import software.amazon.smithy.codegen.core.SymbolProvider; | ||
import software.amazon.smithy.go.codegen.GoDelegator; | ||
import software.amazon.smithy.go.codegen.GoSettings; | ||
import software.amazon.smithy.go.codegen.GoWriter; | ||
import software.amazon.smithy.go.codegen.SmithyGoDependency; | ||
import software.amazon.smithy.go.codegen.SymbolUtils; | ||
import software.amazon.smithy.go.codegen.integration.GoIntegration; | ||
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar; | ||
import software.amazon.smithy.go.codegen.integration.ProtocolUtils; | ||
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin; | ||
import software.amazon.smithy.model.Model; | ||
import software.amazon.smithy.model.shapes.OperationShape; | ||
import software.amazon.smithy.model.shapes.ServiceShape; | ||
import software.amazon.smithy.model.shapes.Shape; | ||
import software.amazon.smithy.model.shapes.StructureShape; | ||
import software.amazon.smithy.utils.ListUtils; | ||
|
||
public class MachineLearningCustomizations implements GoIntegration { | ||
private static final String ADD_PREDICT_ENDPOINT = "AddPredictEndpointMiddleware"; | ||
private static final String ENDPOINT_ACCESSOR = "getPredictEndpoint"; | ||
|
||
@Override | ||
public byte getOrder() { | ||
// This needs to be run after the generic endpoint resolver gets added | ||
return 50; | ||
} | ||
|
||
@Override | ||
public List<RuntimeClientPlugin> getClientPlugins() { | ||
return ListUtils.of( | ||
RuntimeClientPlugin.builder() | ||
.operationPredicate(MachineLearningCustomizations::isPredict) | ||
.registerMiddleware(MiddlewareRegistrar.builder() | ||
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_PREDICT_ENDPOINT, | ||
AwsCustomGoDependency.MACHINE_LEARNING_CUSTOMIZATION).build()) | ||
.functionArguments(ListUtils.of( | ||
SymbolUtils.createValueSymbolBuilder(ENDPOINT_ACCESSOR).build() | ||
)) | ||
.build()) | ||
.build() | ||
); | ||
} | ||
|
||
@Override | ||
public void writeAdditionalFiles( | ||
GoSettings settings, | ||
Model model, | ||
SymbolProvider symbolProvider, | ||
GoDelegator goDelegator | ||
) { | ||
ServiceShape service = settings.getService(model); | ||
if (!isMachineLearning(model, service)) { | ||
return; | ||
} | ||
|
||
service.getAllOperations().stream() | ||
.filter(shapeId -> shapeId.getName().equalsIgnoreCase("Predict")) | ||
.findAny() | ||
.map(model::expectShape) | ||
.flatMap(Shape::asOperationShape) | ||
.ifPresent(operation -> { | ||
goDelegator.useShapeWriter(operation, writer -> writeEndpointAccessor( | ||
writer, model, symbolProvider, operation)); | ||
}); | ||
} | ||
|
||
private void writeEndpointAccessor( | ||
GoWriter writer, | ||
Model model, | ||
SymbolProvider symbolProvider, | ||
OperationShape operation | ||
) { | ||
StructureShape input = ProtocolUtils.expectInput(model, operation); | ||
writer.openBlock("func $L(input interface{}) (*string, error) {", "}", ENDPOINT_ACCESSOR, () -> { | ||
writer.write("in, ok := input.($P)", symbolProvider.toSymbol(input)); | ||
writer.openBlock("if !ok {", "}", () -> { | ||
writer.addUseImports(SmithyGoDependency.SMITHY); | ||
writer.addUseImports(SmithyGoDependency.FMT); | ||
writer.write("return nil, &smithy.SerializationError{Err: fmt.Errorf(" | ||
+ "\"expected $P, but was %T\", input)}", symbolProvider.toSymbol(input)); | ||
}); | ||
writer.write("return in.PredictEndpoint, nil"); | ||
}); | ||
} | ||
|
||
private static boolean isPredict(Model model, ServiceShape service, OperationShape operation) { | ||
return isMachineLearning(model, service) && operation.getId().getName().equalsIgnoreCase("Predict"); | ||
} | ||
|
||
private static boolean isMachineLearning(Model model, ServiceShape service) { | ||
return service.expectTrait(ServiceTrait.class).getSdkId().equalsIgnoreCase("Machine Learning"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
61 changes: 61 additions & 0 deletions
61
service/machinelearning/internal/customizations/predictendpoint.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package customizations | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"github.com/awslabs/smithy-go" | ||
"github.com/awslabs/smithy-go/middleware" | ||
smithyhttp "github.com/awslabs/smithy-go/transport/http" | ||
"net/url" | ||
) | ||
|
||
type fetchPredictEndpointFunc func(interface{}) (*string, error) | ||
|
||
// AddPredictEndpointMiddleware adds the middleware required to set the endpoint | ||
// based on Predict's PredictEndpoint input member. | ||
func AddPredictEndpointMiddleware(stack *middleware.Stack, endpoint fetchPredictEndpointFunc) { | ||
stack.Serialize.Insert(&predictEndpointMiddleware{}, "ResolveEndpoint", middleware.After) | ||
} | ||
|
||
// predictEndpointMiddleware rewrites the endpoint with whatever is specified in the | ||
// operation input if it is non-nil and non-empty. | ||
type predictEndpointMiddleware struct{ | ||
fetchPredictEndpoint fetchPredictEndpointFunc | ||
} | ||
|
||
// ID returns the id for the middleware. | ||
func (*predictEndpointMiddleware) ID() string { return "MachineLearning:PredictEndpoint" } | ||
|
||
// HandleSerialize implements the SerializeMiddleware interface. | ||
func (m *predictEndpointMiddleware) HandleSerialize( | ||
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler, | ||
) ( | ||
out middleware.SerializeOutput, metadata middleware.Metadata, err error, | ||
) { | ||
req, ok := in.Request.(*smithyhttp.Request) | ||
if !ok { | ||
return out, metadata, &smithy.SerializationError{ | ||
Err: fmt.Errorf("unknown request type %T", in.Request), | ||
} | ||
} | ||
|
||
endpoint, err := m.fetchPredictEndpoint(in.Parameters) | ||
if err != nil { | ||
return out, metadata, &smithy.SerializationError{ | ||
Err: fmt.Errorf("failed to fetch PredictEndpoint value, %v", err), | ||
} | ||
} | ||
|
||
if endpoint != nil && len(*endpoint) != 0 { | ||
uri, err := url.Parse(*endpoint) | ||
if err != nil { | ||
return out, metadata, &smithy.SerializationError{ | ||
Err: fmt.Errorf("unable to parse predict endpoint, %v", err), | ||
} | ||
} | ||
req.URL = uri | ||
in.Request = req | ||
} | ||
|
||
return next.HandleSerialize(ctx, in) | ||
} |
75 changes: 75 additions & 0 deletions
75
service/machinelearning/internal/customizations/predictendpoint_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package customizations | ||
|
||
import ( | ||
"context" | ||
"github.com/awslabs/smithy-go/middleware" | ||
"github.com/awslabs/smithy-go/ptr" | ||
smithyhttp "github.com/awslabs/smithy-go/transport/http" | ||
"strings" | ||
"testing" | ||
) | ||
|
||
func TestPredictEndpointMiddleware(t *testing.T) { | ||
cases := map[string]struct { | ||
PredictEndpoint *string | ||
ExpectedEndpoint string | ||
ExpectedErr string | ||
}{ | ||
"nil endpoint": {}, | ||
"empty endpoint": { | ||
PredictEndpoint: ptr.String(""), | ||
}, | ||
"invalid endpoint": { | ||
PredictEndpoint: ptr.String("::::::::"), | ||
ExpectedErr: "unable to parse", | ||
}, | ||
"valid endpoint": { | ||
PredictEndpoint: ptr.String("https://example.amazonaws.com/"), | ||
ExpectedEndpoint: "https://example.amazonaws.com/", | ||
}, | ||
} | ||
|
||
for name, c := range cases { | ||
t.Run(name, func(t *testing.T) { | ||
m := &predictEndpointMiddleware{ | ||
fetchPredictEndpoint: func(i interface{}) (*string, error) { | ||
return c.PredictEndpoint, nil | ||
}, | ||
} | ||
_, _, err := m.HandleSerialize(context.Background(), | ||
middleware.SerializeInput{ | ||
Request: smithyhttp.NewStackRequest(), | ||
}, | ||
middleware.SerializeHandlerFunc( | ||
func(ctx context.Context, input middleware.SerializeInput) ( | ||
output middleware.SerializeOutput, metadata middleware.Metadata, err error, | ||
) { | ||
|
||
req, ok := input.Request.(*smithyhttp.Request) | ||
if !ok || req == nil { | ||
t.Fatalf("expect smithy request, got %T", input.Request) | ||
} | ||
|
||
if c.ExpectedEndpoint != req.URL.String() { | ||
t.Errorf("expected url to be `%v`, but was `%v`", c.ExpectedEndpoint, req.URL.String()) | ||
} | ||
|
||
return output, metadata, err | ||
}), | ||
) | ||
if len(c.ExpectedErr) != 0 { | ||
if err == nil { | ||
t.Fatalf("expect error, got none") | ||
} | ||
if e, a := c.ExpectedErr, err.Error(); !strings.Contains(a, e) { | ||
t.Fatalf("expect error to contain %v, got %v", e, a) | ||
} | ||
} else { | ||
if err != nil { | ||
t.Fatalf("expect no error, got %v", err) | ||
} | ||
} | ||
}) | ||
} | ||
|
||
} |