Skip to content

Commit

Permalink
[Rust-Axum] Support Authentication (Cookie, API Key In Header) (#20017)
Browse files Browse the repository at this point in the history
* [Rust-Axum] Support Cookie Authentication & API Key In Header

* Fix

* Fix

* Fix Header Params & Responses
  • Loading branch information
linxGnu authored Nov 6, 2024
1 parent 06547b7 commit cded99c
Show file tree
Hide file tree
Showing 30 changed files with 2,677 additions and 75 deletions.
11 changes: 11 additions & 0 deletions bin/configs/manual/rust-axum-apikey-auths.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
generatorName: rust-axum
outputDir: samples/server/petstore/rust-axum/output/apikey-auths
inputSpec: modules/openapi-generator/src/test/resources/3_0/jetbrains/CheckoutBasicBearerCookieQueryHeaderBasicBearer.yaml
templateDir: modules/openapi-generator/src/main/resources/rust-axum
generateAliasAsModel: true
additionalProperties:
hideGenerationTimestamp: "true"
packageName: apikey-auths
globalProperties:
skipFormModel: false
enablePostProcessFile: true
14 changes: 7 additions & 7 deletions docs/generators/rust-axum.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ These options may be applied as additional-properties (cli) or configOptions (pl
|Body|✓|OAS2
|FormUnencoded|✓|OAS2
|FormMultipart|✓|OAS2
|Cookie||OAS3
|Cookie||OAS3

### Schema Support Feature
| Name | Supported | Defined By |
Expand All @@ -215,14 +215,14 @@ These options may be applied as additional-properties (cli) or configOptions (pl
### Security Feature
| Name | Supported | Defined By |
| ---- | --------- | ---------- |
|BasicAuth||OAS2,OAS3
|BasicAuth||OAS2,OAS3
|ApiKey|✓|OAS2,OAS3
|OpenIDConnect|✗|OAS3
|BearerToken||OAS3
|OAuth2_Implicit||OAS2,OAS3
|OAuth2_Password||OAS2,OAS3
|OAuth2_ClientCredentials||OAS2,OAS3
|OAuth2_AuthorizationCode||OAS2,OAS3
|BearerToken||OAS3
|OAuth2_Implicit||OAS2,OAS3
|OAuth2_Password||OAS2,OAS3
|OAuth2_ClientCredentials||OAS2,OAS3
|OAuth2_AuthorizationCode||OAS2,OAS3
|SignatureAuth|✗|OAS3
|AWSV4Signature|✗|ToolingExtension

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,13 @@
import io.swagger.v3.oas.models.parameters.RequestBody;
import io.swagger.v3.oas.models.responses.ApiResponse;
import io.swagger.v3.oas.models.servers.Server;
import io.swagger.v3.oas.models.tags.Tag;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.lang3.StringUtils;
import org.openapitools.codegen.*;
import org.openapitools.codegen.meta.GeneratorMetadata;
import org.openapitools.codegen.meta.Stability;
import org.openapitools.codegen.meta.features.GlobalFeature;
import org.openapitools.codegen.meta.features.ParameterFeature;
import org.openapitools.codegen.meta.features.SchemaSupportFeature;
import org.openapitools.codegen.meta.features.WireFormatFeature;
import org.openapitools.codegen.meta.features.*;
import org.openapitools.codegen.model.ModelMap;
import org.openapitools.codegen.model.ModelsMap;
import org.openapitools.codegen.model.OperationMap;
Expand All @@ -53,7 +51,6 @@

public class RustAxumServerCodegen extends AbstractRustCodegen implements CodegenConfig {
public static final String PROJECT_NAME = "openapi-server";
private static final String apiPath = "rust-axum";

private String packageName;
private String packageVersion;
Expand Down Expand Up @@ -99,6 +96,9 @@ public RustAxumServerCodegen() {
WireFormatFeature.JSON,
WireFormatFeature.Custom
))
.securityFeatures(EnumSet.of(
SecurityFeature.ApiKey
))
.excludeGlobalFeatures(
GlobalFeature.Info,
GlobalFeature.ExternalDocumentation,
Expand All @@ -113,9 +113,6 @@ public RustAxumServerCodegen() {
.excludeSchemaSupportFeatures(
SchemaSupportFeature.Polymorphism
)
.excludeParameterFeatures(
ParameterFeature.Cookie
)
);

generatorMetadata = GeneratorMetadata.newBuilder(generatorMetadata)
Expand Down Expand Up @@ -436,7 +433,10 @@ public CodegenOperation fromOperation(String path, String httpMethod, Operation
}
pathMethodOpMap
.computeIfAbsent(axumPath, (key) -> new ArrayList<>())
.add(new MethodOperation(op.httpMethod.toLowerCase(Locale.ROOT), underscoredOperationId));
.add(new MethodOperation(
op.httpMethod.toLowerCase(Locale.ROOT),
underscoredOperationId,
op.vendorExtensions));
}

// Determine the types that this operation produces. `getProducesInfo`
Expand Down Expand Up @@ -488,7 +488,7 @@ public CodegenOperation fromOperation(String path, String httpMethod, Operation
);
rsp.vendorExtensions.put("x-response-id", responseId);
}

if (rsp.dataType != null) {
// Get the mimetype which is produced by this response. Note
// that although in general responses produces a set of
Expand Down Expand Up @@ -562,19 +562,39 @@ public CodegenOperation fromOperation(String path, String httpMethod, Operation
}
}
}

for (CodegenProperty header : rsp.headers) {
if (uuidType.equals(header.dataType)) {
additionalProperties.put("apiUsesUuid", true);
}
header.nameInPascalCase = toModelName(header.baseName);
header.nameInLowerCase = header.baseName.toLowerCase(Locale.ROOT);
}
}

for (CodegenParameter header : op.headerParams) {
header.nameInLowerCase = header.baseName.toLowerCase(Locale.ROOT);
}

for (CodegenProperty header : op.responseHeaders) {
if (uuidType.equals(header.dataType)) {
additionalProperties.put("apiUsesUuid", true);
}
header.nameInPascalCase = toModelName(header.baseName);
header.nameInLowerCase = header.baseName.toLowerCase(Locale.ROOT);
}

// Include renderUuidConversionImpl exactly once in the vendorExtensions map when
// at least one `uuid::Uuid` converted from a header value in the resulting Rust code.
final Boolean renderUuidConversionImpl = op.headerParams.stream().anyMatch(h -> h.getDataType().equals(uuidType));
if (renderUuidConversionImpl) {
// Include renderUuidConversionImpl exactly once in the vendorExtensions map when
// at least one `uuid::Uuid` converted from a header value in the resulting Rust code.
final boolean renderUuidConversionImpl = op.headerParams.stream().anyMatch(h -> h.getDataType().equals(uuidType));
if (renderUuidConversionImpl)
additionalProperties.put("renderUuidConversionImpl", "true");
}

return op;
}

@Override
public OperationsMap postProcessOperationsWithModels(OperationsMap operationsMap, List<ModelMap> allModels) {
public OperationsMap postProcessOperationsWithModels(final OperationsMap operationsMap, List<ModelMap> allModels) {
OperationMap operations = operationsMap.getOperations();
operations.put("classnamePascalCase", camelize(operations.getClassname()));
List<CodegenOperation> operationList = operations.getOperation();
Expand All @@ -586,7 +606,7 @@ public OperationsMap postProcessOperationsWithModels(OperationsMap operationsMap
return operationsMap;
}

private void postProcessOperationWithModels(CodegenOperation op) {
private void postProcessOperationWithModels(final CodegenOperation op) {
boolean consumesJson = false;
boolean consumesPlainText = false;
boolean consumesFormUrlEncoded = false;
Expand Down Expand Up @@ -641,6 +661,22 @@ private void postProcessOperationWithModels(CodegenOperation op) {
param.vendorExtensions.put("x-consumes-json", true);
}
}

if (op.authMethods != null) {
for (CodegenSecurity s : op.authMethods) {
if (s.isApiKey && (s.isKeyInCookie || s.isKeyInHeader)) {
if (s.isKeyInCookie) {
op.vendorExtensions.put("x-has-cookie-auth-methods", "true");
op.vendorExtensions.put("x-api-key-cookie-name", toModelName(s.keyParamName));
} else {
op.vendorExtensions.put("x-has-header-auth-methods", "true");
op.vendorExtensions.put("x-api-key-header-name", toModelName(s.keyParamName));
}

op.vendorExtensions.put("x-has-auth-methods", "true");
}
}
}
}

@Override
Expand Down Expand Up @@ -671,14 +707,16 @@ public void addOperationToGroup(String tag, String resourcePath, Operation opera
return;
}

// Get all tags sorted by name
final List<String> tags = op.tags.stream().map(t -> t.getName()).sorted().collect(Collectors.toList());
// Combine into a single group
final String combinedTag = tags.stream().collect(Collectors.joining("-"));
// Get all tags sorted by name & Combine into a single group
final String combinedTag = op.tags.stream()
.map(Tag::getName).sorted()
.collect(Collectors.joining("-"));
// Add to group
super.addOperationToGroup(combinedTag, resourcePath, operation, op, operations);

return;
}

super.addOperationToGroup(tag, resourcePath, operation, op, operations);
}

Expand All @@ -692,7 +730,7 @@ public void addOperationToGroup(String tag, String resourcePath, Operation opera
// restore things to sensible values.
@Override
public CodegenParameter fromRequestBody(RequestBody body, Set<String> imports, String bodyParameterName) {
Schema original_schema = ModelUtils.getSchemaFromRequestBody(body);
final Schema original_schema = ModelUtils.getSchemaFromRequestBody(body);
CodegenParameter codegenParameter = super.fromRequestBody(body, imports, bodyParameterName);

if (StringUtils.isNotBlank(original_schema.get$ref())) {
Expand All @@ -709,12 +747,12 @@ public CodegenParameter fromRequestBody(RequestBody body, Set<String> imports, S
}

@Override
public String toInstantiationType(Schema p) {
public String toInstantiationType(final Schema p) {
if (ModelUtils.isArraySchema(p)) {
Schema inner = ModelUtils.getSchemaItems(p);
final Schema inner = ModelUtils.getSchemaItems(p);
return instantiationTypes.get("array") + "<" + getSchemaType(inner) + ">";
} else if (ModelUtils.isMapSchema(p)) {
Schema inner = ModelUtils.getAdditionalProperties(p);
final Schema inner = ModelUtils.getAdditionalProperties(p);
return instantiationTypes.get("map") + "<" + typeMapping.get("string") + ", " + getSchemaType(inner) + ">";
} else {
return null;
Expand All @@ -727,7 +765,7 @@ public Map<String, Object> postProcessSupportingFileData(Map<String, Object> bun

final List<PathMethodOperations> pathMethodOps = pathMethodOpMap.entrySet().stream()
.map(entry -> {
ArrayList<MethodOperation> methodOps = entry.getValue();
final ArrayList<MethodOperation> methodOps = entry.getValue();
methodOps.sort(Comparator.comparing(a -> a.method));
return new PathMethodOperations(entry.getKey(), methodOps);
})
Expand All @@ -739,7 +777,7 @@ public Map<String, Object> postProcessSupportingFileData(Map<String, Object> bun
}

@Override
public String toDefaultValue(Schema p) {
public String toDefaultValue(final Schema p) {
String defaultValue = null;
if ((ModelUtils.isNullable(p)) && (p.getDefault() != null) && ("null".equalsIgnoreCase(p.getDefault().toString())))
return "Nullable::Null";
Expand Down Expand Up @@ -899,7 +937,7 @@ protected void updateParameterForString(CodegenParameter codegenParameter, Schem
}

@Override
protected void updatePropertyForAnyType(CodegenProperty property, Schema p) {
protected void updatePropertyForAnyType(final CodegenProperty property, final Schema p) {
// The 'null' value is allowed when the OAS schema is 'any type'.
// See https://github.com/OAI/OpenAPI-Specification/issues/1389
if (Boolean.FALSE.equals(p.getNullable())) {
Expand All @@ -917,7 +955,7 @@ protected void updatePropertyForAnyType(CodegenProperty property, Schema p) {
}

@Override
protected String getParameterDataType(Parameter parameter, Schema schema) {
protected String getParameterDataType(final Parameter parameter, final Schema schema) {
if (parameter.get$ref() != null) {
String refName = ModelUtils.getSimpleRef(parameter.get$ref());
return toModelName(refName);
Expand All @@ -938,10 +976,12 @@ static class PathMethodOperations {
static class MethodOperation {
public String method;
public String operationID;
public Map<String, Object> vendorExtensions;

MethodOperation(String method, String operationID) {
MethodOperation(String method, String operationID, Map<String, Object> vendorExtensions) {
this.method = method;
this.operationID = operationID;
this.vendorExtensions = vendorExtensions;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,19 @@ pub mod {{classFilename}};
{{/apis}}
{{/apiInfo}}

{{#authMethods}}
{{#isApiKey}}
{{#isKeyInCookie}}
/// Cookie Authentication.
pub trait CookieAuthentication {
fn extract_token_from_cookie(&self, cookies: &axum_extra::extract::CookieJar, key: &str) -> Option<String>;
}
{{/isKeyInCookie}}
{{#isKeyInHeader}}
/// API Key Authentication - Header.
pub trait ApiKeyAuthHeader {
fn extract_token_from_header(&self, headers: &axum::http::header::HeaderMap, key: &str) -> Option<String>;
}
{{/isKeyInHeader}}
{{/isApiKey}}
{{/authMethods}}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ pub trait {{classnamePascalCase}} {
method: Method,
host: Host,
cookies: CookieJar,
{{#vendorExtensions}}
{{#x-has-cookie-auth-methods}}
token_in_cookie: Option<String>,
{{/x-has-cookie-auth-methods}}
{{#x-has-header-auth-methods}}
token_in_header: Option<String>,
{{/x-has-header-auth-methods}}
{{/vendorExtensions}}
{{#headerParams.size}}
header_params: models::{{{operationIdCamelCase}}}HeaderParams,
{{/headerParams.size}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A>(
{{#headerParams.size}}
headers: HeaderMap,
{{/headerParams.size}}
{{^headerParams.size}}
{{#vendorExtensions}}
{{#x-has-header-auth-methods}}
headers: HeaderMap,
{{/x-has-header-auth-methods}}
{{/vendorExtensions}}
{{/headerParams.size}}
{{#pathParams.size}}
Path(path_params): Path<models::{{{operationIdCamelCase}}}PathParams>,
{{/pathParams.size}}
Expand Down Expand Up @@ -47,14 +54,34 @@ async fn {{#vendorExtensions}}{{{x-operation-id}}}{{/vendorExtensions}}<I, A>(
) -> Result<Response, StatusCode>
where
I: AsRef<A> + Send + Sync,
A: apis::{{classFilename}}::{{classnamePascalCase}},
A: apis::{{classFilename}}::{{classnamePascalCase}}{{#vendorExtensions}}{{#x-has-cookie-auth-methods}}+ apis::CookieAuthentication{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}+ apis::ApiKeyAuthHeader{{/x-has-header-auth-methods}}{{/vendorExtensions}},
{
{{#vendorExtensions}}
{{#x-has-auth-methods}}
// Authentication
{{/x-has-auth-methods}}
{{#x-has-cookie-auth-methods}}
let token_in_cookie = api_impl.as_ref().extract_token_from_cookie(&cookies, "{{x-api-key-cookie-name}}");
{{/x-has-cookie-auth-methods}}
{{#x-has-header-auth-methods}}
let token_in_header = api_impl.as_ref().extract_token_from_header(&headers, "{{x-api-key-header-name}}");
{{/x-has-header-auth-methods}}
{{#x-has-auth-methods}}
if let ({{#x-has-cookie-auth-methods}}None,{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}None,{{/x-has-header-auth-methods}}) = ({{#x-has-cookie-auth-methods}}&token_in_cookie,{{/x-has-cookie-auth-methods}}{{#x-has-header-auth-methods}}&token_in_header,{{/x-has-header-auth-methods}}) {
return Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR);
}
{{/x-has-auth-methods}}
{{/vendorExtensions}}

{{#headerParams}}
{{#-first}}
// Header parameters
let header_params = {
{{/-first}}
let header_{{{paramName}}} = headers.get(HeaderName::from_static("{{{nameInLowerCase}}}"));
let header_{{{paramName}}} = headers.get(HeaderName::from_static("{{{baseName}}}"));

let header_{{{paramName}}} = match header_{{{paramName}}} {
Some(v) => match header::IntoHeaderValue::<{{{dataType}}}>::try_from((*v).clone()) {
Expand Down Expand Up @@ -154,6 +181,14 @@ where
method,
host,
cookies,
{{#vendorExtensions}}
{{#x-has-cookie-auth-methods}}
token_in_cookie,
{{/x-has-cookie-auth-methods}}
{{#x-has-header-auth-methods}}
token_in_header,
{{/x-has-header-auth-methods}}
{{/vendorExtensions}}
{{#headerParams.size}}
header_params,
{{/headerParams.size}}
Expand Down Expand Up @@ -232,7 +267,7 @@ where
{
let mut response_headers = response.headers_mut().unwrap();
response_headers.insert(
HeaderName::from_static("{{{nameInLowerCase}}}"),
HeaderName::from_static("{{{baseName}}}"),
{{name}}
);
}
Expand Down
Loading

0 comments on commit cded99c

Please sign in to comment.