Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the delegation of determining the errors that can occur for an operation #304

Merged
merged 2 commits into from
Jun 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ private void generateOperationDeserializerMiddleware(GenerationContext context,
goWriter.write("");

Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorMessageCodeDeserializer);
context, operation, responseType, this::writeErrorMessageCodeDeserializer,
this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
deserializeDocumentBindingShapes.addAll(errorShapes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,15 @@
package software.amazon.smithy.go.codegen.integration;

import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
Expand All @@ -29,33 +35,36 @@
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;

public final class HttpProtocolGeneratorUtils {

private HttpProtocolGeneratorUtils() {}
private HttpProtocolGeneratorUtils() {
}

/**
* Generates a function that handles error deserialization by getting the error code then
* dispatching to the error-specific deserializer.
*
* <p>
* If the error code does not map to a known error, a generic error will be returned using
* the error code and error message discovered in the response.
*
* <p>
* The default error message and code are both "UnknownError".
*
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param context The generation context.
* @param operation The operation to generate for.
* @param responseType The response type for the HTTP protocol.
* @param errorMessageCodeGenerator A consumer that generates a snippet that sets the {@code errorCode}
* and {@code errorMessage} variables from the http response.
* @return A set of all error structure shapes for the operation that were dispatched to.
*/
static Set<StructureShape> generateErrorDispatcher(
static Set<StructureShape> generateErrorDispatcher(
GenerationContext context,
OperationShape operation,
Symbol responseType,
Consumer<GenerationContext> errorMessageCodeGenerator
Consumer<GenerationContext> errorMessageCodeGenerator,
BiFunction<GenerationContext, OperationShape, Map<String, ShapeId>> operationErrorsToShapes
) {
GoWriter writer = context.getWriter();
ServiceShape service = context.getService();
Expand All @@ -68,50 +77,49 @@ static Set<StructureShape> generateErrorDispatcher(
writer.addUseImports(SmithyGoDependency.SMITHY_MIDDLEWARE);
writer.openBlock("func $L(response $P, metadata *middleware.Metadata) error {", "}",
errorFunctionName, responseType, () -> {
writer.addUseImports(SmithyGoDependency.BYTES);
writer.addUseImports(SmithyGoDependency.IO);

// Copy the response body into a seekable type
writer.write("var errorBuffer bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&errorBuffer, response.Body); err != nil {", "}", () -> {
writer.write("return &smithy.DeserializationError{Err: fmt.Errorf("
+ "\"failed to copy error response body, %w\", err)}");
});
writer.write("errorBody := bytes.NewReader(errorBuffer.Bytes())");
writer.write("");

// Set the default values for code and message.
writer.write("errorCode := \"UnknownError\"");
writer.write("errorMessage := errorCode");
writer.write("");

// Dispatch to the message/code generator to try to get the specific code and message.
errorMessageCodeGenerator.accept(context);

writer.openBlock("switch {", "}", () -> {
new TreeSet<>(operation.getErrors()).forEach(errorId -> {
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
errorShapes.add(error);
String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName(
error, service, protocolName);
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.openBlock("case strings.EqualFold($S, errorCode):", "", errorId.getName(service), () -> {
writer.write("return $L(response, errorBody)", errorDeserFunctionName);
writer.addUseImports(SmithyGoDependency.BYTES);
writer.addUseImports(SmithyGoDependency.IO);

// Copy the response body into a seekable type
writer.write("var errorBuffer bytes.Buffer");
writer.openBlock("if _, err := io.Copy(&errorBuffer, response.Body); err != nil {", "}", () -> {
writer.write("return &smithy.DeserializationError{Err: fmt.Errorf("
+ "\"failed to copy error response body, %w\", err)}");
});
});

// Create a generic error
writer.addUseImports(SmithyGoDependency.SMITHY);
writer.openBlock("default:", "", () -> {
writer.openBlock("genericError := &smithy.GenericAPIError{", "}", () -> {
writer.write("Code: errorCode,");
writer.write("Message: errorMessage,");
writer.write("errorBody := bytes.NewReader(errorBuffer.Bytes())");
writer.write("");

// Set the default values for code and message.
writer.write("errorCode := \"UnknownError\"");
writer.write("errorMessage := errorCode");
writer.write("");

// Dispatch to the message/code generator to try to get the specific code and message.
errorMessageCodeGenerator.accept(context);

writer.openBlock("switch {", "}", () -> {
operationErrorsToShapes.apply(context, operation).forEach((name, errorId) -> {
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
errorShapes.add(error);
String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName(
error, service, protocolName);
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.openBlock("case strings.EqualFold($S, errorCode):", "", name, () -> {
writer.write("return $L(response, errorBody)", errorDeserFunctionName);
});
});

// Create a generic error
writer.addUseImports(SmithyGoDependency.SMITHY);
writer.openBlock("default:", "", () -> {
writer.openBlock("genericError := &smithy.GenericAPIError{", "}", () -> {
writer.write("Code: errorCode,");
writer.write("Message: errorMessage,");
});
writer.write("return genericError");
});
});
writer.write("return genericError");
});
});
});
writer.write("");
}).write("");

return errorShapes;
}
Expand All @@ -136,4 +144,24 @@ public static boolean isShapeWithResponseBindings(Model model, Shape shape, Http
}
return false;
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operation the operation shape to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
public static Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
return operation.getErrors().stream()
.collect(Collectors.toMap(
shapeId -> shapeId.getName(context.getService()),
Function.identity(),
(x, y) -> {
if (!x.equals(y)) {
throw new CodegenException(String.format("conflicting error shape ids: %s, %s", x, y));
}
return x;
}, TreeMap::new));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ private void generateOperationDeserializer(GenerationContext context, OperationS
writer.write("");

Set<StructureShape> errorShapes = HttpProtocolGeneratorUtils.generateErrorDispatcher(
context, operation, responseType, this::writeErrorMessageCodeDeserializer);
context, operation, responseType, this::writeErrorMessageCodeDeserializer,
this::getOperationErrors);
deserializingErrorShapes.addAll(errorShapes);
deserializingDocumentShapes.addAll(errorShapes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.smithy.codegen.core.CodegenException;
import software.amazon.smithy.codegen.core.SymbolProvider;
Expand Down Expand Up @@ -248,6 +249,17 @@ static String getDeserializeMiddlewareName(ShapeId operationShapeId, ServiceShap
+ operationShapeId.getName(service);
}

/**
* Returns a map of error names to their {@link ShapeId}.
*
* @param context the generation context
* @param operation the operation shape to retrieve errors for
* @return map of error names to {@link ShapeId}
*/
default Map<String, ShapeId> getOperationErrors(GenerationContext context, OperationShape operation) {
return HttpProtocolGeneratorUtils.getOperationErrors(context, operation);
}

/**
* Context object used for service serialization and deserialization.
*/
Expand Down