Skip to content

Commit

Permalink
feat: add basic generics support (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
ubogdan authored Jun 16, 2022
1 parent 67cb768 commit ff41d9c
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Dockerfile References: https://docs.docker.com/engine/reference/builder/

# Start from the latest golang base image
FROM golang:1.17-alpine as builder
FROM golang:1.18.3-alpine as builder

# Set the Current Working Directory inside the container
WORKDIR /app
Expand Down
109 changes: 109 additions & 0 deletions generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//go:build go1.18
// +build go1.18

package swag

import (
"go/ast"
"strings"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

if typeSpecDef.TypeSpec.TypeParams != nil {
fullName = fullName + "["
for i, typeParam := range typeSpecDef.TypeSpec.TypeParams.List {
if i > 0 {
fullName = fullName + "-"
}

fullName = fullName + typeParam.Names[0].Name
}
fullName = fullName + "]"
}

return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[")
if len(genericParams) == 1 {
return nil
}

genericParams = strings.Split(genericParams[1], ",")
for i, p := range genericParams {
genericParams[i] = strings.TrimSpace(p)
}
genericParamTypeDefs := map[string]*TypeSpecDef{}

if len(genericParams) != len(original.TypeSpec.TypeParams.List) {
return nil
}

for i, genericParam := range genericParams {
tdef, ok := pkgDefs.uniqueDefinitions[genericParam]
if !ok {
return nil
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef
}

parametrizedTypeSpec := &TypeSpecDef{
File: original.File,
PkgPath: original.PkgPath,
TypeSpec: &ast.TypeSpec{
Doc: original.TypeSpec.Doc,
Comment: original.TypeSpec.Comment,
Assign: original.TypeSpec.Assign,
},
}

ident := &ast.Ident{
NamePos: original.TypeSpec.Name.NamePos,
Obj: original.TypeSpec.Name.Obj,
}

genNameParts := strings.Split(fullGenericForm, "[")
if strings.Contains(genNameParts[0], ".") {
genNameParts[0] = strings.Split(genNameParts[0], ".")[1]
}

ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1)
ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1)

parametrizedTypeSpec.TypeSpec.Name = ident

origStructType := original.TypeSpec.Type.(*ast.StructType)

newStructTypeDef := &ast.StructType{
Struct: origStructType.Struct,
Incomplete: origStructType.Incomplete,
Fields: &ast.FieldList{
Opening: origStructType.Fields.Opening,
Closing: origStructType.Fields.Closing,
},
}

for _, field := range origStructType.Fields.List {
newField := &ast.Field{
Doc: field.Doc,
Names: field.Names,
Tag: field.Tag,
Comment: field.Comment,
}
if genTypeSpec, ok := genericParamTypeDefs[field.Type.(*ast.Ident).Name]; ok {
newField.Type = genTypeSpec.TypeSpec.Type
} else {
newField.Type = field.Type
}

newStructTypeDef.Fields.List = append(newStructTypeDef.Fields.List, newField)
}

parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef

return parametrizedTypeSpec
}
12 changes: 12 additions & 0 deletions generics_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
//go:build !go1.18
// +build !go1.18

package swag

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
return original
}
210 changes: 210 additions & 0 deletions generics_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
//go:build go1.18
// +build go1.18

package swag

import (
"encoding/json"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParseGenericsBasic(t *testing.T) {
t.Parallel()

expected := `{
"swagger": "2.0",
"info": {
"description": "This is a sample server Petstore server.",
"title": "Swagger Example API",
"contact": {},
"version": "1.0"
},
"host": "localhost:4000",
"basePath": "/api",
"paths": {
"/posts/{post_id}": {
"get": {
"description": "get string by ID",
"consumes": [
"application/json"
],
"produces": [
"application/json"
],
"summary": "Add a new pet to the store",
"parameters": [
{
"type": "integer",
"format": "int64",
"description": "Some ID",
"name": "post_id",
"in": "path",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"$ref": "#/definitions/web.GenericResponse-web_Post"
}
},
"222": {
"description": "",
"schema": {
"$ref": "#/definitions/web.GenericResponseMulti-web_Post-web_Post"
}
},
"400": {
"description": "We need ID!!",
"schema": {
"$ref": "#/definitions/web.APIError"
}
},
"404": {
"description": "Can not find ID",
"schema": {
"$ref": "#/definitions/web.APIError"
}
}
}
}
}
},
"definitions": {
"web.APIError": {
"description": "API error with information about it",
"type": "object",
"properties": {
"createdAt": {
"description": "Error time",
"type": "string"
},
"error": {
"description": "Error an Api error",
"type": "string"
},
"errorCtx": {
"description": "Error ` + "`context`" + ` tick comment",
"type": "string"
},
"errorNo": {
"description": "Error ` + "`number`" + ` tick comment",
"type": "integer"
}
}
},
"web.GenericResponse-web_Post": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"status": {
"type": "string"
}
}
},
"web.GenericResponseMulti-web_Post-web_Post": {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"meta": {
"type": "object",
"properties": {
"data": {
"description": "Post data",
"type": "object",
"properties": {
"name": {
"description": "Post tag",
"type": "array",
"items": {
"type": "string"
}
}
}
},
"id": {
"type": "integer",
"format": "int64",
"example": 1
},
"name": {
"description": "Post name",
"type": "string",
"example": "poti"
}
}
},
"status": {
"type": "string"
}
}
}
}
}`

searchDir := "testdata/generics_basic"
p := New()
err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth)
assert.NoError(t, err)
b, _ := json.MarshalIndent(p.swagger, "", " ")
os.WriteFile("testdata/generics_basic/swagger.json", b, 0644)
assert.Equal(t, expected, string(b))
}
15 changes: 12 additions & 3 deletions operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func (operation *Operation) ParseParamComment(commentLine string, astFile *ast.F
if objectType == PRIMITIVE {
param.Schema = PrimitiveSchema(refType)
} else {
schema, err := operation.parseAPIObjectSchema(objectType, refType, astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, objectType, refType, astFile)
if err != nil {
return err
}
Expand Down Expand Up @@ -933,7 +933,16 @@ func (operation *Operation) parseCombinedObjectSchema(refType string, astFile *a
}), nil
}

func (operation *Operation) parseAPIObjectSchema(schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
func (operation *Operation) parseAPIObjectSchema(commentLine, schemaType, refType string, astFile *ast.File) (*spec.Schema, error) {
if strings.HasSuffix(refType, ",") && strings.Contains(refType, "[") {
// regexp may have broken generics syntax. find closing bracket and add it back
allMatchesLenOffset := strings.Index(commentLine, refType) + len(refType)
lostPartEndIdx := strings.Index(commentLine[allMatchesLenOffset:], "]")
if lostPartEndIdx >= 0 {
refType += commentLine[allMatchesLenOffset : allMatchesLenOffset+lostPartEndIdx+1]
}
}

switch schemaType {
case OBJECT:
if !strings.HasPrefix(refType, "[]") {
Expand Down Expand Up @@ -969,7 +978,7 @@ func (operation *Operation) ParseResponseComment(commentLine string, astFile *as

description := strings.Trim(matches[4], "\"")

schema, err := operation.parseAPIObjectSchema(strings.Trim(matches[2], "{}"), matches[3], astFile)
schema, err := operation.parseAPIObjectSchema(commentLine, strings.Trim(matches[2], "{}"), matches[3], astFile)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit ff41d9c

Please sign in to comment.