Skip to content

Commit

Permalink
refactor: minor tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane committed Nov 11, 2023
1 parent cf82275 commit 6de3643
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 47 deletions.
47 changes: 21 additions & 26 deletions env.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package env provides an API for loading environment variables into structs.
// Package env implements loading environment variables into a config struct.
package env

import (
Expand All @@ -8,55 +8,51 @@ import (
"strings"
)

// Options are options for the [Load] function.
// Options are the options for the [Load] function.
type Options struct {
Source Source // The source of environment variables. The default is [OS].
SliceSep string // The separator used to parse slice values. The default is space.
}

// NotSetError is returned when environment variables are marked as required but not set.
type NotSetError struct {
// The names of the missing required environment variables.
Names []string
Names []string // The names of the missing environment variables.
}

// Error implements the error interface.
func (e *NotSetError) Error() string {
return fmt.Sprintf("env: %v are required but not set", e.Names)
if len(e.Names) == 1 {
return fmt.Sprintf("env: %s is required but not set", e.Names[0])
}
return fmt.Sprintf("env: %s are required but not set", strings.Join(e.Names, " "))
}

// Load loads environment variables into the provided struct using the [OS] [Source].
// Load loads environment variables into the given struct.
// cfg must be a non-nil struct pointer, otherwise Load panics.
// If opts is nil, the default [Options] are used.
//
// The struct fields must have the `env:"VAR"` struct tag, where VAR is the name of the corresponding environment variable.
// The struct fields must have the `env:"VAR"` struct tag,
// where VAR is the name of the corresponding environment variable.
// Unexported fields are ignored.
//
// # Supported types
//
// The following types are supported:
// - int (any kind)
// - float (any kind)
// - bool
// - string
// - [time.Duration]
// - [encoding.TextUnmarshaler]
// - slices of any type above (space is the default separator for values)
//
// See the [strconv].Parse* functions for parsing rules.
// Implementing the [encoding.TextUnmarshaler] interface is enough to use any user-defined type.
// Nested structs of any depth level are supported, only the leaves of the config tree must have the `env` tag.
//
// # Default values
// - slices of any type above
// - nested structs of any depth
//
// Default values can be specified either using the `default` struct tag (has a higher priority) or by initializing the struct fields directly.
// See the [strconv].Parse* functions for the parsing rules.
// User-defined types can be used by implementing the [encoding.TextUnmarshaler] interface.
//
// # Per-variable options
//
// The name of the environment variable can be followed by comma-separated options in the form of `env:"VAR,option1,option2,..."`:
// Default values can be specified using the `default:"VALUE"` struct tag.
//
// The name of an environment variable can be followed by comma-separated options:
// - required: marks the environment variable as required
// - expand: expands the value of the environment variable using [os.Expand]
//
// If environment variables are marked as required but not set, an error of type [NotSetError] will be returned.
func Load(cfg any, opts *Options) error {
if opts == nil {
opts = new(Options)
Expand Down Expand Up @@ -113,20 +109,19 @@ func parseVars(v reflect.Value) []Var {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanSet() {
continue // skip unexported fields.
continue
}

// special case: a nested struct, parse its fields recursively.
if kindOf(field, reflect.Struct) && !implements(field, unmarshalerIface) {
nested := parseVars(field)
vars = append(vars, nested...)
vars = append(vars, parseVars(field)...)
continue
}

sf := v.Type().Field(i)
value, ok := sf.Tag.Lookup("env")
if !ok {
continue // skip fields without the `env` tag.
continue
}

parts := strings.Split(value, ",")
Expand Down
2 changes: 1 addition & 1 deletion env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestLoad(t *testing.T) {
})

t.Run("unsupported type", func(t *testing.T) {
m := env.Map{"FOO": "1 + 2i"}
m := env.Map{"FOO": "1+2i"}

var cfg struct {
Foo complex64 `env:"FOO"`
Expand Down
14 changes: 6 additions & 8 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,19 @@ func ExampleLoad_nestedStruct() {
}

func ExampleLoad_required() {
os.Unsetenv("HOST")
os.Unsetenv("PORT")

var cfg struct {
Host string `env:"HOST,required"`
Port int `env:"PORT,required"`
Port int `env:"PORT,required"`
}
if err := env.Load(&cfg, nil); err != nil {
var notSetErr *env.NotSetError
if errors.As(err, &notSetErr) {
fmt.Println(notSetErr.Names)
fmt.Println(notSetErr)
}
}

// Output: [HOST PORT]
// Output: env: PORT is required but not set
}

func ExampleLoad_expand() {
Expand Down Expand Up @@ -100,12 +98,12 @@ func ExampleLoad_source() {
}

func ExampleLoad_sliceSeparator() {
os.Setenv("PORTS", "8080;8081;8082")
os.Setenv("PORTS", "8080,8081,8082")

var cfg struct {
Ports []int `env:"PORTS"`
}
if err := env.Load(&cfg, &env.Options{SliceSep: ";"}); err != nil {
if err := env.Load(&cfg, &env.Options{SliceSep: ","}); err != nil {
fmt.Println(err)
}

Expand All @@ -129,7 +127,7 @@ func ExampleUsage() {
env.Usage(&cfg, os.Stdout)
}

// Output: env: [DB_HOST DB_PORT] are required but not set
// Output: env: DB_HOST DB_PORT are required but not set
// Usage:
// DB_HOST string required database host
// DB_PORT int required database port
Expand Down
16 changes: 4 additions & 12 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ var (
unmarshalerIface = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
)

// typeOf reports whether v's type is one of the provided types.
func typeOf(v reflect.Value, types ...reflect.Type) bool {
for _, t := range types {
if t == v.Type() {
Expand All @@ -23,7 +22,6 @@ func typeOf(v reflect.Value, types ...reflect.Type) bool {
return false
}

// kindOf reports whether v's kind is one of the provided kinds.
func kindOf(v reflect.Value, kinds ...reflect.Kind) bool {
for _, k := range kinds {
if k == v.Kind() {
Expand All @@ -33,22 +31,19 @@ func kindOf(v reflect.Value, kinds ...reflect.Kind) bool {
return false
}

// implements reports whether v's type implements one of the provided interfaces.
func implements(v reflect.Value, ifaces ...reflect.Type) bool {
for _, iface := range ifaces {
if t := v.Type(); t.Implements(iface) || reflect.PtrTo(v.Type()).Implements(iface) {
if t := v.Type(); t.Implements(iface) || reflect.PtrTo(t).Implements(iface) {
return true
}
}
return false
}

// structPtr reports whether v is a non-nil struct pointer.
func structPtr(v reflect.Value) bool {
return v.IsValid() && v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct && !v.IsNil()
}

// setValue parses s based on v's type/kind and sets v's underlying value to the result.
func setValue(v reflect.Value, s string) error {
switch {
case typeOf(v, durationType):
Expand All @@ -71,8 +66,7 @@ func setValue(v reflect.Value, s string) error {
}

func setInt(v reflect.Value, s string) error {
bits := v.Type().Bits()
i, err := strconv.ParseInt(s, 10, bits)
i, err := strconv.ParseInt(s, 10, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing int: %w", err)
}
Expand All @@ -81,8 +75,7 @@ func setInt(v reflect.Value, s string) error {
}

func setUint(v reflect.Value, s string) error {
bits := v.Type().Bits()
u, err := strconv.ParseUint(s, 10, bits)
u, err := strconv.ParseUint(s, 10, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing uint: %w", err)
}
Expand All @@ -91,8 +84,7 @@ func setUint(v reflect.Value, s string) error {
}

func setFloat(v reflect.Value, s string) error {
bits := v.Type().Bits()
f, err := strconv.ParseFloat(s, bits)
f, err := strconv.ParseFloat(s, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing float: %w", err)
}
Expand Down

0 comments on commit 6de3643

Please sign in to comment.