Skip to content

Commit

Permalink
fix(spanner): MarshalJSON function caused errors for certain values (#…
Browse files Browse the repository at this point in the history
…9063)

Fix inspired by the bigquery package
  • Loading branch information
danysousa authored Nov 30, 2023
1 parent 0685da5 commit afe7c98
Showing 1 changed file with 17 additions and 37 deletions.
54 changes: 17 additions & 37 deletions spanner/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,7 @@ func (n NullInt64) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullInt64.
func (n NullInt64) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Int64)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Int64)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullInt64.
Expand Down Expand Up @@ -270,10 +267,7 @@ func (n NullString) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullString.
func (n NullString) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.StringVal)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.StringVal)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullString.
Expand Down Expand Up @@ -358,10 +352,7 @@ func (n NullFloat64) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullFloat64.
func (n NullFloat64) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Float64)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Float64)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat64.
Expand Down Expand Up @@ -441,10 +432,7 @@ func (n NullBool) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullBool.
func (n NullBool) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Bool)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Bool)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullBool.
Expand Down Expand Up @@ -524,10 +512,7 @@ func (n NullTime) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullTime.
func (n NullTime) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.String())), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Time)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullTime.
Expand Down Expand Up @@ -612,10 +597,7 @@ func (n NullDate) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullDate.
func (n NullDate) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.String())), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Date)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullDate.
Expand Down Expand Up @@ -701,7 +683,7 @@ func (n NullNumeric) String() string {
// MarshalJSON implements json.Marshaler.MarshalJSON for NullNumeric.
func (n NullNumeric) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", NumericString(&n.Numeric))), nil
return json.Marshal(NumericString(&n.Numeric))
}
return jsonNullBytes, nil
}
Expand Down Expand Up @@ -800,10 +782,7 @@ func (n NullJSON) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullJSON.
func (n NullJSON) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Value)
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Value)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullJSON.
Expand Down Expand Up @@ -851,10 +830,7 @@ func (n PGNumeric) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for PGNumeric.
func (n PGNumeric) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.Numeric)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Numeric)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGNumeric.
Expand Down Expand Up @@ -912,10 +888,7 @@ func (n PGJsonB) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for PGJsonB.
func (n PGJsonB) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Value)
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Value)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGJsonB.
Expand All @@ -937,6 +910,13 @@ func (n *PGJsonB) UnmarshalJSON(payload []byte) error {
return nil
}

func nulljson(valid bool, v interface{}) ([]byte, error) {
if !valid {
return jsonNullBytes, nil
}
return json.Marshal(v)
}

// GenericColumnValue represents the generic encoded value and type of the
// column. See google.spanner.v1.ResultSet proto for details. This can be
// useful for proxying query results when the result types are not known in
Expand Down

0 comments on commit afe7c98

Please sign in to comment.