Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to Smithy 1.7.0 #289

Merged
merged 8 commits into from
May 5, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import software.amazon.smithy.go.codegen.knowledge.GoPointableIndex;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.MapShape;
import software.amazon.smithy.model.shapes.MemberShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
Expand Down Expand Up @@ -313,6 +314,21 @@ public static CollectionShape expectCollectionShape(Shape shape) {
throw new CodegenException("expect shape " + shape.getId() + " to be Collection, was " + shape.getType());
}

/**
* Returns the shape unpacked as a MapShape. Throws and exception if the passed in
* shape is not a map.
*
* @param shape the map shape.
* @return The unpacked MapShape.
*/
public static MapShape expectMapShape(Shape shape) {
if (shape instanceof MapShape) {
return (MapShape) (shape);
}

throw new CodegenException("expect shape " + shape.getId() + " to be Map, was " + shape.getType());
}

/**
* Comparator to sort ShapeMember lists alphabetically, with required members first followed by optional members.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ final class CodegenVisitor extends ShapeVisitor.Default<Void> {
service = settings.getService(model);
LOGGER.info(() -> "Generating Go client for service " + service.getId());

SymbolProvider resolvedProvider = GoCodegenPlugin.createSymbolProvider(model, settings.getModuleName());
SymbolProvider resolvedProvider = GoCodegenPlugin.createSymbolProvider(model, settings);
for (GoIntegration integration : integrations) {
resolvedProvider = integration.decorateSymbolProvider(settings, model, resolvedProvider);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ public void execute(PluginContext context) {
/**
* Creates a Go symbol provider.
*
* @param model The model to generate symbols for.
* @param rootModuleName The name of the package root.
* @param model The model to generate symbols for.
* @param settings The Gosettings to use to create symbol provider
* @return Returns the created provider.
*/
public static SymbolProvider createSymbolProvider(Model model, String rootModuleName) {
return new SymbolVisitor(model, rootModuleName);
public static SymbolProvider createSymbolProvider(Model model, GoSettings settings) {
return new SymbolVisitor(model, settings);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,13 @@ final class SymbolVisitor implements SymbolProvider, ShapeVisitor<Symbol> {
private final ReservedWordSymbolProvider.Escaper errorMemberEscaper;
private final Map<ShapeId, ReservedWordSymbolProvider.Escaper> structureSpecificMemberEscapers = new HashMap<>();
private final GoPointableIndex pointableIndex;
private final GoSettings settings;


SymbolVisitor(Model model, String rootModuleName) {
SymbolVisitor(Model model, GoSettings settings) {
this.model = model;
this.rootModuleName = rootModuleName;
this.typesPackageName = rootModuleName + "/types";
this.settings = settings;
this.rootModuleName = settings.getModuleName();
this.typesPackageName = this.rootModuleName + "/types";
this.pointableIndex = GoPointableIndex.of(model);

// Reserve the generated names for union members, including the unknown case.
Expand Down Expand Up @@ -223,7 +224,8 @@ private String formatUnionMemberName(UnionShape union, MemberShape member) {
}

private String getDefaultShapeName(Shape shape) {
return StringUtils.capitalize(removeLeadingInvalidIdentCharacters(shape.getId().getName()));
ServiceShape serviceShape = model.expectShape(settings.getService(), ServiceShape.class);
return StringUtils.capitalize(removeLeadingInvalidIdentCharacters(shape.getId().getName(serviceShape)));
}

private String getDefaultMemberName(MemberShape shape) {
Expand Down Expand Up @@ -406,7 +408,8 @@ private Symbol createBigSymbol(Shape shape, String symbolName) {

@Override
public Symbol documentShape(DocumentShape shape) {
return symbolBuilderFor(shape, "Document", SmithyGoDependency.SMITHY)
String name = getDefaultShapeName(shape);
return symbolBuilderFor(shape, name, SmithyGoDependency.SMITHY)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SyntheticClone;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.DocumentShape;
Expand Down Expand Up @@ -414,7 +415,12 @@ protected final void generateDeserFunction(
GoWriter writer = context.getWriter();

Symbol symbol = symbolProvider.toSymbol(shape);
String functionName = ProtocolGenerator.getDocumentDeserializerFunctionName(shape, context.getProtocolName());

String functionName = shape.hasTrait(SyntheticClone.class)
? ProtocolGenerator.getOperationDocumentDeserFuncName(
shape, context.getProtocolName())
: ProtocolGenerator.getDocumentDeserializerFunctionName(
shape, context.getService(), context.getProtocolName());
skotambkar marked this conversation as resolved.
Show resolved Hide resolved

String additionalArguments = getAdditionalArguments().entrySet().stream()
.map(entry -> String.format(", %s %s", entry.getKey(), entry.getValue()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SyntheticClone;
import software.amazon.smithy.go.codegen.integration.ProtocolGenerator.GenerationContext;
import software.amazon.smithy.model.shapes.CollectionShape;
import software.amazon.smithy.model.shapes.DocumentShape;
Expand Down Expand Up @@ -323,7 +324,12 @@ private void generateSerFunction(
GoWriter writer = context.getWriter();

Symbol symbol = symbolProvider.toSymbol(shape);
String functionName = ProtocolGenerator.getDocumentSerializerFunctionName(shape, context.getProtocolName());

String functionName = shape.hasTrait(SyntheticClone.class)
? ProtocolGenerator.getOperationDocumentSerFuncName(
shape, context.getProtocolName())
: ProtocolGenerator.getDocumentSerializerFunctionName(
shape, context.getService(), context.getProtocolName());
skotambkar marked this conversation as resolved.
Show resolved Hide resolved

String additionalArguments = getAdditionalSerArguments().entrySet().stream()
.map(entry -> String.format(", %s %s", entry.getKey(), entry.getValue()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ private boolean isRestBinding(HttpBinding.Location location) {
|| location == HttpBinding.Location.PREFIX_HEADERS
|| location == HttpBinding.Location.LABEL
|| location == HttpBinding.Location.QUERY
|| location == HttpBinding.Location.QUERY_PARAMS
|| location == HttpBinding.Location.RESPONSE_CODE;
}

Expand Down Expand Up @@ -665,14 +666,14 @@ private void writeHttpBindingSetter(
}

private void writeHttpBindingMember(
GenerationContext context,
final GenerationContext context,
HttpBinding binding
) {
GoWriter writer = context.getWriter();
Model model = context.getModel();
MemberShape memberShape = binding.getMember();
Shape targetShape = model.expectShape(memberShape.getTarget());
HttpBinding.Location location = binding.getLocation();
final Shape targetShape = model.expectShape(memberShape.getTarget());
final HttpBinding.Location location = binding.getLocation();

// return an error if member shape targets location label, but is unset.
if (location.equals(HttpBinding.Location.LABEL)) {
Expand Down Expand Up @@ -726,31 +727,85 @@ private void writeHttpBindingMember(
});
break;
case QUERY:
if (targetShape instanceof CollectionShape) {
MemberShape collectionMember = CodegenUtils.expectCollectionShape(targetShape)
.getMember();
writer.openBlock("for i := range $L {", "}", operand, () -> {
GoValueAccessUtils.writeIfZeroValue(context.getModel(), writer, collectionMember,
operand + "[i]", () -> {
writer.write("continue");
});
writeHttpBindingSetter(context, writer, collectionMember, location, operand + "[i]",
(w, s) -> {
w.writeInline("encoder.AddQuery($S).$L", locationName, s);
});
});
} else {
writeHttpBindingSetter(context, writer, memberShape, location, operand,
(w, s) -> w.writeInline(
"encoder.SetQuery($S).$L", locationName, s));
}
writeQueryBinding(context, memberShape, targetShape, operand,
location, locationName, "encoder", false);
break;
case QUERY_PARAMS:
MemberShape queryMapValueMemberShape = CodegenUtils.expectMapShape(targetShape).getValue();
Shape queryMapValueTargetShape = model.expectShape(queryMapValueMemberShape.getTarget());
MemberShape queryMapKeyMemberShape = CodegenUtils.expectMapShape(targetShape).getKey();
writer.openBlock("for qkey, qvalue := range $L {", "}", operand, () -> {
writer.write("if encoder.HasQuery(qkey) { continue }");
writeQueryBinding(context, queryMapKeyMemberShape, queryMapValueTargetShape,
"qvalue", location, "qkey", "encoder", true);
});
break;

default:
throw new CodegenException("unexpected http binding found");
}
});
}

/**
* Writes query bindings, as per the target shape. This method is shared
* between members modeled with Location.Query and Location.QueryParams.
* Precedence across Location.Query and Location.QueryParams is handled
* outside the scope of this function.
*
* @param context is the generation context
* @param memberShape is the member shape for which query is serialized
* @param targetShape is the target shape of the query member.
* This can either be string, or a list/set of string.
* @param operand is the member value accessor .
* @param location is the location of the member - can be Location.Query
* or Location.QueryParams.
* @param locationName is the key for which query is encoded.
* @param dest is the query encoder destination.
* @param isQueryParams boolean representing if Location used for query binding is
* QUERY_PARAMS.
*/
private void writeQueryBinding(
GenerationContext context,
MemberShape memberShape,
Shape targetShape,
String operand,
HttpBinding.Location location,
String locationName,
String dest,
boolean isQueryParams
) {
GoWriter writer = context.getWriter();

if (targetShape instanceof CollectionShape) {
MemberShape collectionMember = CodegenUtils.expectCollectionShape(targetShape)
.getMember();
writer.openBlock("for i := range $L {", "}", operand, () -> {
GoValueAccessUtils.writeIfZeroValue(context.getModel(), writer, collectionMember,
operand + "[i]", () -> {
writer.write("continue");
});
skotambkar marked this conversation as resolved.
Show resolved Hide resolved
writeHttpBindingSetter(context, writer, collectionMember, location, operand + "[i]",
(w, s) -> {
if (isQueryParams) {
w.writeInline("$L.AddQuery($L).$L", dest, locationName, s);
} else {
w.writeInline("$L.AddQuery($S).$L", dest, locationName, s);
}
});
skotambkar marked this conversation as resolved.
Show resolved Hide resolved
});
} else {
writeHttpBindingSetter(context, writer, memberShape, location, operand,
(w, s) -> {
if (isQueryParams) {
w.writeInline("$L.SetQuery($L).$L", dest, locationName, s);
} else {
w.writeInline("$L.SetQuery($S).$L", dest, locationName, s);
}
});
}
}

private void writeHeaderBinding(
GenerationContext context,
MemberShape memberShape,
Expand Down Expand Up @@ -1228,7 +1283,8 @@ private void addDocumentDeserializerBindingShapes(Model model, HttpBindingIndex

private void generateErrorDeserializer(GenerationContext context, StructureShape shape) {
GoWriter writer = context.getWriter();
String functionName = ProtocolGenerator.getErrorDeserFunctionName(shape, context.getProtocolName());
String functionName = ProtocolGenerator.getErrorDeserFunctionName(
shape, context.getService(), context.getProtocolName());
Symbol responseType = getApplicationProtocol().getResponseType();

writer.addUseImports(SmithyGoDependency.BYTES);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static Set<StructureShape> generateErrorDispatcher(
StructureShape error = context.getModel().expectShape(errorId).asStructureShape().get();
errorShapes.add(error);
String errorDeserFunctionName = ProtocolGenerator.getErrorDeserFunctionName(
error, context.getProtocolName());
error, context.getService(), context.getProtocolName());
writer.addUseImports(SmithyGoDependency.STRINGS);
writer.openBlock("case strings.EqualFold($S, errorCode):", "", errorId.getName(), () -> {
writer.write("return $L(response, errorBody)", errorDeserFunctionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Consumer;
Expand All @@ -38,6 +39,7 @@
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.model.traits.IdempotencyTokenTrait;
import software.amazon.smithy.protocoltests.traits.AppliesTo;
import software.amazon.smithy.protocoltests.traits.HttpMessageTestCase;
import software.amazon.smithy.utils.SmithyBuilder;

Expand Down Expand Up @@ -222,6 +224,11 @@ public void generateTestFunction(GoWriter writer) {
generateTestCaseParams(writer);
writer.openBlock("}{", "}", () -> {
for (T testCase : testCases) {
Optional<AppliesTo> appliesTo = testCase.getAppliesTo();
if (appliesTo.isPresent() && !(appliesTo.get().equals(AppliesTo.CLIENT))) {
continue;
}

testCase.getDocumentation().ifPresent(writer::writeDocs);
writer.openBlock("$S: {", "},", testCase.getId(), () -> {
generateTestCaseValues(writer, testCase);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ private void generateOperationDeserializer(GenerationContext context, OperationS

private void generateErrorDeserializer(GenerationContext context, StructureShape shape) {
GoWriter writer = context.getWriter();
String functionName = ProtocolGenerator.getErrorDeserFunctionName(shape, context.getProtocolName());
String functionName = ProtocolGenerator.getErrorDeserFunctionName(
shape, context.getService(), context.getProtocolName());
Symbol responseType = getApplicationProtocol().getResponseType();

writer.addUseImports(SmithyGoDependency.BYTES);
Expand Down
Loading