Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(spanner): MarshalJSON function caused errors for certain values #9063

Merged
merged 2 commits into from
Nov 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 17 additions & 37 deletions spanner/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,7 @@ func (n NullInt64) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullInt64.
func (n NullInt64) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Int64)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Int64)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullInt64.
Expand Down Expand Up @@ -270,10 +267,7 @@ func (n NullString) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullString.
func (n NullString) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.StringVal)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.StringVal)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullString.
Expand Down Expand Up @@ -358,10 +352,7 @@ func (n NullFloat64) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullFloat64.
func (n NullFloat64) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Float64)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Float64)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullFloat64.
Expand Down Expand Up @@ -441,10 +432,7 @@ func (n NullBool) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullBool.
func (n NullBool) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%v", n.Bool)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Bool)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullBool.
Expand Down Expand Up @@ -524,10 +512,7 @@ func (n NullTime) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullTime.
func (n NullTime) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.String())), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Time)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullTime.
Expand Down Expand Up @@ -612,10 +597,7 @@ func (n NullDate) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullDate.
func (n NullDate) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.String())), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Date)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullDate.
Expand Down Expand Up @@ -701,7 +683,7 @@ func (n NullNumeric) String() string {
// MarshalJSON implements json.Marshaler.MarshalJSON for NullNumeric.
func (n NullNumeric) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", NumericString(&n.Numeric))), nil
return json.Marshal(NumericString(&n.Numeric))
}
return jsonNullBytes, nil
}
Expand Down Expand Up @@ -800,10 +782,7 @@ func (n NullJSON) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for NullJSON.
func (n NullJSON) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Value)
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Value)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for NullJSON.
Expand Down Expand Up @@ -851,10 +830,7 @@ func (n PGNumeric) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for PGNumeric.
func (n PGNumeric) MarshalJSON() ([]byte, error) {
if n.Valid {
return []byte(fmt.Sprintf("%q", n.Numeric)), nil
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Numeric)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGNumeric.
Expand Down Expand Up @@ -912,10 +888,7 @@ func (n PGJsonB) String() string {

// MarshalJSON implements json.Marshaler.MarshalJSON for PGJsonB.
func (n PGJsonB) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Value)
}
return jsonNullBytes, nil
return nulljson(n.Valid, n.Value)
}

// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGJsonB.
Expand All @@ -937,6 +910,13 @@ func (n *PGJsonB) UnmarshalJSON(payload []byte) error {
return nil
}

func nulljson(valid bool, v interface{}) ([]byte, error) {
if !valid {
return jsonNullBytes, nil
}
return json.Marshal(v)
}

// GenericColumnValue represents the generic encoded value and type of the
// column. See google.spanner.v1.ResultSet proto for details. This can be
// useful for proxying query results when the result types are not known in
Expand Down
Loading