diff --git a/metaschema/templates/template.go b/metaschema/templates/template.go index 3a7a11c..4a8e0a9 100644 --- a/metaschema/templates/template.go +++ b/metaschema/templates/template.go @@ -16,22 +16,20 @@ import ( ) func GenerateAll(metaschema *parser.Metaschema, baseDir string) error { - return GenerateModels(metaschema, baseDir) -} - -func GenerateModels(metaschema *parser.Metaschema, baseDir string) error { - t, err := newTemplate(baseDir) + pkgDir, err := ensurePkgDir(metaschema, baseDir) if err != nil { return err } + return GenerateModels(metaschema, baseDir, pkgDir) +} - packageName := metaschema.GoPackageName() - dir := filepath.Join(baseDir, packageName) - err = os.MkdirAll(dir, os.FileMode(0722)) +func GenerateModels(metaschema *parser.Metaschema, baseDir, pkgDir string) error { + t, err := newTemplate(baseDir) if err != nil { return err } - f, err := os.Create(fmt.Sprintf("%s/generated_models.go", dir)) + + f, err := os.Create(fmt.Sprintf("%s/generated_models.go", pkgDir)) if err != nil { return err } @@ -86,3 +84,9 @@ func newTemplate(baseDir string) (*template.Template, error) { "getImports": getImports, }).Parse(string(tempText)) } + +func ensurePkgDir(metaschema *parser.Metaschema, baseDir string) (string, error) { + dir := filepath.Join(baseDir, metaschema.GoPackageName()) + err := os.MkdirAll(dir, os.FileMode(0722)) + return dir, err +}