Skip to content

Commit

Permalink
Generate a single swagger.json file for all frameworks (#1437)
Browse files Browse the repository at this point in the history
* Update tools for swagger generation

* Regenerate SDK

* Cleanup

* Use absolute path for swagger file

* Refactor imports
  • Loading branch information
alembiewski authored Oct 9, 2021
1 parent 7cdc253 commit 75d4c3b
Show file tree
Hide file tree
Showing 48 changed files with 124 additions and 1,130 deletions.
23 changes: 3 additions & 20 deletions hack/python-sdk/gen-sdk.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ SWAGGER_JAR_URL="https://repo1.maven.org/maven2/org/openapitools/openapi-generat
SWAGGER_CODEGEN_JAR="${repo_root}/hack/python-sdk/openapi-generator-cli.jar"
SWAGGER_CODEGEN_CONF="${repo_root}/hack/python-sdk/swagger_config.json"
SDK_OUTPUT_PATH="${repo_root}/sdk/python"
FRAMEWORKS=(tensorflow pytorch mxnet xgboost)
VERSION=1.3.0
SWAGGER_CODEGEN_FILE="${repo_root}/hack/python-sdk/swagger.json"

if [ -z "${GOPATH:-}" ]; then
export GOPATH=$(go env GOPATH)
Expand All @@ -39,32 +39,15 @@ if [[ ! -f "$SWAGGER_CODEGEN_JAR" ]]; then
wget -O "${SWAGGER_CODEGEN_JAR}" ${SWAGGER_JAR_URL}
fi


for FRAMEWORK in ${FRAMEWORKS[@]}; do
SWAGGER_CODEGEN_FILE="pkg/apis/${FRAMEWORK}/v1/swagger.json"
echo "Generating swagger file for ${FRAMEWORK} ..."
go run "${repo_root}"/hack/python-sdk/main.go "${FRAMEWORK}" ${VERSION} > "${SWAGGER_CODEGEN_FILE}"
done

echo "Merging swagger files from different frameworks into one"
download_url=$(curl -s https://api.github.com/repos/go-swagger/go-swagger/releases/latest | \
jq -r '.assets[] | select(.name | contains("'"$(uname | tr '[:upper:]' '[:lower:]')"'_amd64")) | .browser_download_url')
curl -o /tmp/swagger -L'#' "$download_url"
chmod +x /tmp/swagger

# it will report warning like 'v1.SchedulingPolicy' already exists in primary or higher priority mixin, skipping
# error code is not 0 but t's acceptable.
/tmp/swagger mixin "${repo_root}"/pkg/apis/tensorflow/v1/swagger.json "${repo_root}"/pkg/apis/pytorch/v1/swagger.json \
"${repo_root}"/pkg/apis/mxnet/v1/swagger.json "${repo_root}"/pkg/apis/xgboost/v1/swagger.json \
--output "${repo_root}"/hack/python-sdk/swagger.json --quiet || true
echo "Generating swagger file ..."
go run "${repo_root}"/hack/python-sdk/main.go ${VERSION} > "${SWAGGER_CODEGEN_FILE}"

echo "Removing previously generated files ..."
rm -rf "${SDK_OUTPUT_PATH}"/docs/V1*.md "${SDK_OUTPUT_PATH}"/kubeflow/training/models "${SDK_OUTPUT_PATH}"/kubeflow/training/*.py "${SDK_OUTPUT_PATH}"/test/*.py
echo "Generating Python SDK for Training Operator ..."
java -jar "${SWAGGER_CODEGEN_JAR}" generate -i "${repo_root}"/hack/python-sdk/swagger.json -g python -o "${SDK_OUTPUT_PATH}" -c "${SWAGGER_CODEGEN_CONF}"

echo "Kubeflow Training Operator Python SDK is generated successfully to folder ${SDK_OUTPUT_PATH}/."
rm /tmp/swagger

echo "Running post-generation script ..."
"${repo_root}"/hack/python-sdk/post_gen.py
71 changes: 37 additions & 34 deletions hack/python-sdk/main.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
Copyright 2019 kubeflow.org.
Copyright 2021 kubeflow.org.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -19,52 +19,52 @@ package main
import (
"encoding/json"
"fmt"
mxnet "github.com/kubeflow/tf-operator/pkg/apis/mxnet/v1"
pytorch "github.com/kubeflow/tf-operator/pkg/apis/pytorch/v1"
tensorflow "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
xgboost "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
"os"
"strings"

"github.com/go-openapi/spec"
mxJob "github.com/kubeflow/tf-operator/pkg/apis/mxnet/v1"
pytorchJob "github.com/kubeflow/tf-operator/pkg/apis/pytorch/v1"
tfjob "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/v1"
xgboostJob "github.com/kubeflow/tf-operator/pkg/apis/xgboost/v1"
"k8s.io/klog"
"k8s.io/kube-openapi/pkg/common"
)

// Generate OpenAPI spec definitions for TFJob Resource
// Generate OpenAPI spec definitions for API resources
func main() {
if len(os.Args) <= 2 {
klog.Fatal("Supply a framework and version")
if len(os.Args) <= 1 {
klog.Fatal("Supply a version")
}
framework := os.Args[1]
version := os.Args[2]
version := os.Args[1]
if !strings.HasPrefix(version, "v") {
version = "v" + version
}
var oAPIDefs map[string]common.OpenAPIDefinition
var oAPIDefs = map[string]common.OpenAPIDefinition{}
defs := spec.Definitions{}

switch framework {
case "tensorflow":
oAPIDefs = tfjob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "pytorch":
oAPIDefs = pytorchJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "mxnet":
oAPIDefs = mxJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
case "xgboost":
oAPIDefs = xgboostJob.GetOpenAPIDefinitions(func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name, framework)))
})
refCallback := func(name string) spec.Ref {
return spec.MustCreateRef("#/definitions/" + common.EscapeJsonPointer(swaggify(name)))
}

for k, v := range tensorflow.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range pytorch.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range mxnet.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

for k, v := range xgboost.GetOpenAPIDefinitions(refCallback) {
oAPIDefs[k] = v
}

defs := spec.Definitions{}
for defName, val := range oAPIDefs {
defs[swaggify(defName, framework)] = val.Schema
defs[swaggify(defName)] = val.Schema
}
swagger := spec.Swagger{
SwaggerProps: spec.SwaggerProps{
Expand All @@ -73,8 +73,8 @@ func main() {
Paths: &spec.Paths{Paths: map[string]spec.PathItem{}},
Info: &spec.Info{
InfoProps: spec.InfoProps{
Title: framework,
Description: fmt.Sprintf("Python SDK for %v", framework),
Title: "Kubeflow Training SDK",
Description: "Python SDK for Kubeflow Training",
Version: version,
},
},
Expand All @@ -87,8 +87,11 @@ func main() {
fmt.Println(string(jsonBytes))
}

func swaggify(name, framework string) string {
name = strings.Replace(name, fmt.Sprintf("github.com/kubeflow/tf-operator/pkg/apis/%s/", framework), "", -1)
func swaggify(name string) string {
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/tensorflow/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/pytorch/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/mxnet/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/tf-operator/pkg/apis/xgboost/", "", -1)
name = strings.Replace(name, "github.com/kubeflow/common/pkg/apis/common/", "", -1)
name = strings.Replace(name, "k8s.io/api/core/", "", -1)
name = strings.Replace(name, "k8s.io/apimachinery/pkg/apis/meta/", "", -1)
Expand Down
Loading

0 comments on commit 75d4c3b

Please sign in to comment.