diff --git a/dbscan/dbscan.go b/dbscan/dbscan.go index a4a0535..3cf20a8 100644 --- a/dbscan/dbscan.go +++ b/dbscan/dbscan.go @@ -53,6 +53,7 @@ type API struct { fieldMapperFn NameMapperFunc scannableTypesOption []interface{} scannableTypesReflect []reflect.Type + allowUnknownColumns bool } // APIOption is a function type that changes API configuration. @@ -61,9 +62,10 @@ type APIOption func(api *API) // NewAPI creates a new API object with provided list of options. func NewAPI(opts ...APIOption) (*API, error) { api := &API{ - structTagKey: "db", - columnSeparator: ".", - fieldMapperFn: SnakeCaseMapper, + structTagKey: "db", + columnSeparator: ".", + fieldMapperFn: SnakeCaseMapper, + allowUnknownColumns: false, } for _, o := range opts { o(api) @@ -129,6 +131,14 @@ func WithScannableTypes(scannableTypes ...interface{}) APIOption { } } +// WithAllowUnknownColumns allows the scanner to ignore db columns that doesn't exist at the destination. +// The default function is to throw an error when a db column ain't found at the destination. +func WithAllowUnknownColumns(allowUnknownColumns bool) APIOption { + return func(api *API) { + api.allowUnknownColumns = allowUnknownColumns + } +} + // ScanAll iterates all rows to the end. After iterating it closes the rows, // and propagates any errors that could pop up. // It expects that destination should be a slice. For each row it scans data and appends it to the destination slice. diff --git a/dbscan/dbscan_test.go b/dbscan/dbscan_test.go index 82a55ae..62d3c78 100644 --- a/dbscan/dbscan_test.go +++ b/dbscan/dbscan_test.go @@ -357,6 +357,23 @@ func TestNewAPI_WithScannableTypes_InvalidInput(t *testing.T) { } } +func TestScanRow_withAllowUnknownColumns_returnsRow(t *testing.T) { + t.Parallel() + rows := queryRows(t, singleRowsQuery) + defer rows.Close() // nolint: errcheck + rows.Next() + + got := &struct{ Foo string }{} + testAPIWithUnknownColumns, err := getAPI(dbscan.WithAllowUnknownColumns(true)) + require.NoError(t, err) + err = testAPIWithUnknownColumns.ScanRow(got, rows) + require.NoError(t, err) + requireNoRowsErrorsAndClose(t, rows) + + expected := struct{ Foo string }{Foo: "foo val"} + assert.Equal(t, expected, *got) +} + func TestMain(m *testing.M) { exitCode := func() int { flag.Parse() diff --git a/dbscan/helpers_test.go b/dbscan/helpers_test.go index f852056..5476689 100644 --- a/dbscan/helpers_test.go +++ b/dbscan/helpers_test.go @@ -40,14 +40,16 @@ func queryRows(t *testing.T, query string) dbscan.Rows { return rows } -func getAPI() (*dbscan.API, error) { - return dbscan.NewAPI( - dbscan.WithScannableTypes( - (*sql.Scanner)(nil), - (*pgtype.TextDecoder)(nil), - (*pgtype.BinaryDecoder)(nil), - ), - ) +func getAPI(opts ...dbscan.APIOption) (*dbscan.API, error) { + if len(opts) < 1 { + opts = []dbscan.APIOption{} + } + opts = append(opts, dbscan.WithScannableTypes( + (*sql.Scanner)(nil), + (*pgtype.TextDecoder)(nil), + (*pgtype.BinaryDecoder)(nil), + )) + return dbscan.NewAPI(opts...) } func scan(t *testing.T, dst interface{}, rows dbscan.Rows) error { diff --git a/dbscan/rowscanner.go b/dbscan/rowscanner.go index b24c642..5b96706 100644 --- a/dbscan/rowscanner.go +++ b/dbscan/rowscanner.go @@ -125,6 +125,11 @@ func (rs *RowScanner) scanStruct(structValue reflect.Value) error { for i, column := range rs.columns { fieldIndex, ok := rs.columnToFieldIndex[column] if !ok { + if rs.api.allowUnknownColumns { + var tmp interface{} + scans[i] = &tmp + continue + } return errors.Errorf( "scany: column: '%s': no corresponding field found, or it's unexported in %v", column, structValue.Type(),