Skip to content
This repository has been archived by the owner on Dec 16, 2024. It is now read-only.

Commit

Permalink
feat(go): extract comments from struct definitions (#22)
Browse files Browse the repository at this point in the history
Fixes #21
  • Loading branch information
alecthomas authored May 4, 2023
1 parent 66e143e commit 1cadad1
Show file tree
Hide file tree
Showing 11 changed files with 263 additions and 171 deletions.
6 changes: 6 additions & 0 deletions console/src/protos/xyz/block/ftl/v1/schema/schema_pb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ export class Data extends Message<Data> {
*/
metadata: Metadata[] = [];

/**
* @generated from field: repeated string comments = 5;
*/
comments: string[] = [];

constructor(data?: PartialMessage<Data>) {
super();
proto3.util.initPartial(data, this);
Expand All @@ -113,6 +118,7 @@ export class Data extends Message<Data> {
{ no: 2, name: "name", kind: "scalar", T: 9 /* ScalarType.STRING */ },
{ no: 3, name: "fields", kind: "message", T: Field, repeated: true },
{ no: 4, name: "metadata", kind: "message", T: Metadata, repeated: true },
{ no: 5, name: "comments", kind: "scalar", T: 9 /* ScalarType.STRING */, repeated: true },
]);

static fromBinary(bytes: Uint8Array, options?: Partial<BinaryReadOptions>): Data {
Expand Down
5 changes: 5 additions & 0 deletions examples/echo/echo.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// This is the echo module.
//
//ftl:module echo
package echo

Expand All @@ -11,6 +13,7 @@ import (
ftl "github.com/TBD54566975/ftl/sdk-go"
)

// An echo request.
type EchoRequest struct {
Name string `json:"name"`
}
Expand All @@ -19,6 +22,8 @@ type EchoResponse struct {
Message string `json:"message"`
}

// Echo returns a greeting with the current time.
//
//ftl:verb
func Echo(ctx context.Context, req EchoRequest) (EchoResponse, error) {
tresp, err := ftl.Call(ctx, timemodule.Time, timemodule.TimeRequest{})
Expand Down
2 changes: 2 additions & 0 deletions examples/time/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type TimeResponse struct {
Time int `json:"time"`
}

// Time returns the current time.
//
//ftl:verb
func Time(ctx context.Context, req TimeRequest) (TimeResponse, error) {
return TimeResponse{Time: int(time.Now().Unix())}, nil
Expand Down
276 changes: 143 additions & 133 deletions protos/xyz/block/ftl/v1/schema/schema.pb.go

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions protos/xyz/block/ftl/v1/schema/schema.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ message Data {
string name = 2;
repeated Field fields = 3;
repeated Metadata metadata = 4;
repeated string comments = 5;
}

message DataRef {
Expand Down
1 change: 1 addition & 0 deletions schema/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ func (d *Data) schemaChildren() []Node {
}
func (d *Data) String() string {
w := &strings.Builder{}
fmt.Fprint(w, encodeComments(d.Comments))
fmt.Fprintf(w, "data %s {\n", d.Name)
for _, f := range d.Fields {
fmt.Fprintln(w, indent(f.String()))
Expand Down
5 changes: 3 additions & 2 deletions schema/protobuf_dec.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ func verbToSchema(s *pschema.Verb) *Verb {

func dataToSchema(s *pschema.Data) *Data {
return &Data{
Name: s.Name,
Fields: fieldListToSchema(s.Fields),
Name: s.Name,
Fields: fieldListToSchema(s.Fields),
Comments: s.Comments,
}
}

Expand Down
5 changes: 3 additions & 2 deletions schema/protobuf_enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ func (v *Verb) ToProto() proto.Message {

func (d *Data) ToProto() proto.Message {
return &pschema.Data{
Name: d.Name,
Fields: nodeListToProto[*pschema.Field](d.Fields),
Name: d.Name,
Fields: nodeListToProto[*pschema.Field](d.Fields),
Comments: d.Comments,
}
}

Expand Down
1 change: 1 addition & 0 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ type DataRef Ref
type Data struct {
Pos Position `json:"pos,omitempty" parser:"" protobuf:"1,optional"`

Comments []string `parser:"@Comment*" json:"comments,omitempty" protobuf:"5"`
Name string `parser:"'data' @Ident '{'" json:"name,omitempty" protobuf:"2"`
Fields []*Field `parser:"@@* '}'" json:"fields,omitempty" protobuf:"3"`
Metadata []Metadata `parser:"@@*" json:"metadata,omitempty" protobuf:"4"`
Expand Down
2 changes: 2 additions & 0 deletions scripts/autofmt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ buf format -w
# Revert changes to generated files.
find protos \( -name '*.pb.go' -o -name '*.connect.go' \) -print0 | xargs -0 -r git checkout > /dev/null

(cd protos && buf generate)

git diff
130 changes: 96 additions & 34 deletions sdk-go/static.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/iancoleman/strcase"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"

"github.com/TBD54566975/ftl/common/goast"
Expand Down Expand Up @@ -46,6 +47,7 @@ func ExtractModule(dir string) (*schema.Module, error) {
}
module := &schema.Module{}
for _, pkg := range pkgs {
pctx := &parseContext{pkg: pkg, pkgs: pkgs, module: module}
for _, file := range pkg.Syntax {
var verb *schema.Verb
err := goast.Visit(file, func(node ast.Node, next func() error) (err error) {
Expand All @@ -56,17 +58,17 @@ func ExtractModule(dir string) (*schema.Module, error) {
}()
switch node := node.(type) {
case *ast.CallExpr:
if err := visitCallExpr(verb, node, pkg); err != nil {
if err := visitCallExpr(pctx, verb, node); err != nil {
return err
}

case *ast.File:
if err := visitFile(module, node, pkg); err != nil {
if err := visitFile(pctx, node); err != nil {
return err
}

case *ast.FuncDecl:
verb, err = visitFuncDecl(pkg, module, node)
verb, err = visitFuncDecl(pctx, node)
if err != nil {
return err
}
Expand All @@ -93,8 +95,8 @@ func ExtractModule(dir string) (*schema.Module, error) {
return module, schema.ValidateModule(module)
}

func visitCallExpr(verb *schema.Verb, node *ast.CallExpr, pkg *packages.Package) error {
_, fn := deref[*types.Func](pkg, node.Fun)
func visitCallExpr(pctx *parseContext, verb *schema.Verb, node *ast.CallExpr) error {
_, fn := deref[*types.Func](pctx.pkg, node.Fun)
if fn == nil {
return nil
}
Expand All @@ -104,12 +106,12 @@ func visitCallExpr(verb *schema.Verb, node *ast.CallExpr, pkg *packages.Package)
if len(node.Args) != 3 {
return errors.New("Call must have exactly three arguments")
}
_, verbFn := deref[*types.Func](pkg, node.Args[1])
_, verbFn := deref[*types.Func](pctx.pkg, node.Args[1])
if verbFn == nil {
return errors.Errorf("Call first argument must be a function but is %s", node.Args[1])
}
moduleName := verbFn.Pkg().Name()
if moduleName == pkg.Name {
if moduleName == pctx.pkg.Name {
moduleName = ""
}
ref := &schema.VerbRef{
Expand All @@ -121,25 +123,25 @@ func visitCallExpr(verb *schema.Verb, node *ast.CallExpr, pkg *packages.Package)
return nil
}

func visitFile(module *schema.Module, node *ast.File, pkg *packages.Package) error {
func visitFile(pctx *parseContext, node *ast.File) error {
if node.Doc == nil {
return nil
}
directives, err := parseDirectives(fset, node.Doc)
if err != nil {
return errors.WithStack(err)
}
module.Comments = parseComments(node.Doc)
pctx.module.Comments = parseComments(node.Doc)
for _, dir := range directives {
switch dir.kind {
case "module":
if dir.id == "" {
return errors.Errorf("%s: module not specified", dir)
}
if dir.id != pkg.Name {
return errors.Errorf("%s: FTL module name %q does not match Go package name %q", dir, dir.id, pkg.Name)
if dir.id != pctx.pkg.Name {
return errors.Errorf("%s: FTL module name %q does not match Go package name %q", dir, dir.id, pctx.pkg.Name)
}
module.Name = dir.id
pctx.module.Name = dir.id

default:
return errors.Errorf("%s: invalid directive", dir)
Expand Down Expand Up @@ -186,12 +188,12 @@ func goPosToSchemaPos(pos token.Pos) schema.Position {
}

// "verbIndex" is the index into the Module.Decls of the verb that was parsed.
func visitFuncDecl(pkg *packages.Package, module *schema.Module, node *ast.FuncDecl) (verb *schema.Verb, err error) {
func visitFuncDecl(pctx *parseContext, node *ast.FuncDecl) (verb *schema.Verb, err error) {
if node.Doc == nil {
return nil, nil
}
fnt := pkg.TypesInfo.Defs[node.Name].(*types.Func) //nolint:forcetypeassert
sig := fnt.Type().(*types.Signature) //nolint:forcetypeassert
fnt := pctx.pkg.TypesInfo.Defs[node.Name].(*types.Func) //nolint:forcetypeassert
sig := fnt.Type().(*types.Signature) //nolint:forcetypeassert
if sig.Recv() != nil {
return nil, errors.Errorf("ftl:verb cannot be a method")
}
Expand All @@ -200,11 +202,11 @@ func visitFuncDecl(pkg *packages.Package, module *schema.Module, node *ast.FuncD
if err := checkSignature(sig); err != nil {
return nil, err
}
req, err := parseStruct(pkg, module, node, params.At(1).Type())
req, err := parseStruct(pctx, node, params.At(1).Type())
if err != nil {
return nil, err
}
resp, err := parseStruct(pkg, module, node, results.At(0).Type())
resp, err := parseStruct(pctx, node, results.At(0).Type())
if err != nil {
return nil, err
}
Expand All @@ -215,7 +217,7 @@ func visitFuncDecl(pkg *packages.Package, module *schema.Module, node *ast.FuncD
Request: req,
Response: resp,
}
module.Decls = append(module.Decls, verb)
pctx.module.Decls = append(pctx.module.Decls, verb)
return verb, nil
}

Expand All @@ -227,22 +229,44 @@ func parseComments(doc *ast.CommentGroup) []string {
return comments
}

func parseStruct(pkg *packages.Package, module *schema.Module, node ast.Node, tnode types.Type) (*schema.DataRef, error) {
func parseStruct(pctx *parseContext, node ast.Node, tnode types.Type) (*schema.DataRef, error) {
named, ok := tnode.(*types.Named)
if !ok {
return nil, errors.Errorf("expected named type but got %s", tnode)
}
s, ok := tnode.Underlying().(*types.Struct)
if !ok {
return nil, errors.Errorf("expected struct but got %s", tnode)
}
out := &schema.Data{
Pos: goPosToSchemaPos(node.Pos()),
Name: named.Obj().Name(),
}

// Find type declaration so we can extract comments.
pos := named.Obj().Pos()
pkg, path, _ := pctx.pathEnclosingInterval(pos, pos)
if pkg != nil {
for i := len(path) - 1; i >= 0; i-- {
// We have to check both the type spec and the gen decl because the
// type could be declared as either "type Foo struct { ... }" or
// "type ( Foo struct { ... } )"
switch path := path[i].(type) {
case *ast.TypeSpec:
if path.Doc != nil {
out.Comments = parseComments(path.Doc)
}
case *ast.GenDecl:
if path.Doc != nil {
out.Comments = parseComments(path.Doc)
}
}
}
}

s, ok := tnode.Underlying().(*types.Struct)
if !ok {
return nil, errors.Errorf("expected struct but got %s", tnode)
}
for i := 0; i < s.NumFields(); i++ {
f := s.Field(i)
ft, err := parseType(pkg, module, node, f.Type())
ft, err := parseType(pctx, node, f.Type())
if err != nil {
return nil, errors.WithStack(err)
}
Expand All @@ -252,14 +276,14 @@ func parseStruct(pkg *packages.Package, module *schema.Module, node ast.Node, tn
Type: ft,
})
}
module.AddData(out)
pctx.module.AddData(out)
return &schema.DataRef{
Pos: goPosToSchemaPos(node.Pos()),
Name: out.Name,
}, nil
}

func parseType(pkg *packages.Package, module *schema.Module, node ast.Node, tnode types.Type) (schema.Type, error) {
func parseType(pctx *parseContext, node ast.Node, tnode types.Type) (schema.Type, error) {
switch tnode := tnode.Underlying().(type) {
case *types.Basic:
switch tnode.Kind() {
Expand All @@ -280,25 +304,25 @@ func parseType(pkg *packages.Package, module *schema.Module, node ast.Node, tnod
}

case *types.Struct:
return parseStruct(pkg, module, node, tnode)
return parseStruct(pctx, node, tnode)

case *types.Map:
return parseMap(pkg, module, node, tnode)
return parseMap(pctx, node, tnode)

case *types.Slice:
return parseSlice(pkg, module, node, tnode)
return parseSlice(pctx, node, tnode)

default:
return nil, errors.Errorf("unsupported type %s", node)
}
}

func parseMap(pkg *packages.Package, module *schema.Module, node ast.Node, tnode *types.Map) (*schema.Map, error) {
key, err := parseType(pkg, module, node, tnode.Key())
func parseMap(pctx *parseContext, node ast.Node, tnode *types.Map) (*schema.Map, error) {
key, err := parseType(pctx, node, tnode.Key())
if err != nil {
return nil, errors.WithStack(err)
}
value, err := parseType(pkg, module, node, tnode.Elem())
value, err := parseType(pctx, node, tnode.Elem())
if err != nil {
return nil, errors.WithStack(err)
}
Expand All @@ -309,8 +333,8 @@ func parseMap(pkg *packages.Package, module *schema.Module, node ast.Node, tnode
}, nil
}

func parseSlice(pkg *packages.Package, module *schema.Module, node ast.Node, tnode *types.Slice) (*schema.Array, error) {
value, err := parseType(pkg, module, node, tnode.Elem())
func parseSlice(pctx *parseContext, node ast.Node, tnode *types.Slice) (*schema.Array, error) {
value, err := parseType(pctx, node, tnode.Elem())
if err != nil {
return nil, errors.WithStack(err)
}
Expand Down Expand Up @@ -464,3 +488,41 @@ func deref[T types.Object](pkg *packages.Package, node ast.Expr) (string, T) {
return "", obj
}
}

type parseContext struct {
pkg *packages.Package
pkgs []*packages.Package
module *schema.Module
}

// pathEnclosingInterval returns the PackageInfo and ast.Node that
// contain source interval [start, end), and all the node's ancestors
// up to the AST root. It searches all ast.Files of all packages in prog.
// exact is defined as for astutil.PathEnclosingInterval.
//
// The zero value is returned if not found.
func (p *parseContext) pathEnclosingInterval(start, end token.Pos) (pkg *packages.Package, path []ast.Node, exact bool) {
for _, info := range p.pkgs {
for _, f := range info.Syntax {
if f.Pos() == token.NoPos {
// This can happen if the parser saw
// too many errors and bailed out.
// (Use parser.AllErrors to prevent that.)
continue
}
if !tokenFileContainsPos(fset.File(f.Pos()), start) {
continue
}
if path, exact := astutil.PathEnclosingInterval(f, start, end); path != nil {
return info, path, exact
}
}
}
return nil, nil, false
}

func tokenFileContainsPos(f *token.File, pos token.Pos) bool {
p := int(pos)
base := f.Base()
return base <= p && p < base+f.Size()
}

0 comments on commit 1cadad1

Please sign in to comment.