From 2b5542e17734565c2c77493077de29e5b2f3f9cf Mon Sep 17 00:00:00 2001
From: Lz <imlangzi@qq.com>
Date: Fri, 26 Apr 2024 11:07:46 +0800
Subject: [PATCH] fix(time): added nullable Time with better json support

---
 CHANGELOG.md |   4 ++
 time.go      |  56 +++++++++++++++++++++++++++
 time_test.go | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 164 insertions(+)
 create mode 100644 time.go
 create mode 100644 time_test.go

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e4ca6df..c3ef8ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [1.4.7]
+### Added
+- added `Connector` interface (#43)
+- added nullable `Time` with better json support (#44)
 
 ## [1.4.6] - 2014-04-23
 ### Changed
diff --git a/time.go b/time.go
new file mode 100644
index 0000000..99bd024
--- /dev/null
+++ b/time.go
@@ -0,0 +1,56 @@
+package sqle
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"encoding/json"
+	"time"
+)
+
+// Time represents a nullable time value.
+type Time struct {
+	sql.NullTime
+}
+
+// NewTime creates a new Time object with the given time and valid flag.
+func NewTime(t time.Time, valid bool) Time {
+	return Time{NullTime: sql.NullTime{Time: t, Valid: valid}}
+}
+
+// Scan implements the [sql.Scanner] interface.
+func (t *Time) Scan(value any) error { // skipcq: GO-W1029
+	return t.NullTime.Scan(value)
+}
+
+// Value implements the [driver.Valuer] interface.
+func (t Time) Value() (driver.Value, error) { // skipcq: GO-W1029
+	return t.NullTime.Value()
+}
+
+// Time returns the underlying time.Time value of the Time struct.
+func (t *Time) Time() time.Time { // skipcq: GO-W1029
+	return t.NullTime.Time
+}
+
+// MarshalJSON implements the json.Marshaler interface
+func (t Time) MarshalJSON() ([]byte, error) { // skipcq: GO-W1029
+	return json.Marshal(t.NullTime.Time)
+}
+
+// UnmarshalJSON implements the json.Unmarshaler interface
+func (t *Time) UnmarshalJSON(data []byte) error { // skipcq: GO-W1029
+	if data == nil {
+		return nil
+	}
+
+	var v time.Time
+	err := json.Unmarshal(data, &v)
+	if err != nil {
+		return err
+	}
+
+	t.NullTime.Time = v
+	t.NullTime.Valid = true
+
+	return nil
+}
diff --git a/time_test.go b/time_test.go
new file mode 100644
index 0000000..a9e9f35
--- /dev/null
+++ b/time_test.go
@@ -0,0 +1,104 @@
+package sqle
+
+import (
+	"database/sql"
+	"encoding/json"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestTimeInSQL(t *testing.T) {
+
+	now := time.Now()
+	d, err := sql.Open("sqlite3", "file::memory:")
+	require.NoError(t, err)
+
+	_, err = d.Exec("CREATE TABLE `times` (`id` id NOT NULL,`created_at` datetime, PRIMARY KEY (`id`))")
+	require.NoError(t, err)
+
+	result, err := d.Exec("INSERT INTO `times`(`id`) VALUES(?)", 10)
+	require.NoError(t, err)
+
+	rows, err := result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(1), rows)
+
+	result, err = d.Exec("INSERT INTO `times`(`id`, `created_at`) VALUES(?, ?)", 20, now)
+	require.NoError(t, err)
+
+	rows, err = result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(1), rows)
+
+	var t10 Time
+	err = d.QueryRow("SELECT `created_at` FROM `times` WHERE id=?", 10).Scan(&t10)
+	require.NoError(t, err)
+
+	require.EqualValues(t, false, t10.Valid)
+
+	var t20 Time
+	err = d.QueryRow("SELECT `created_at` FROM `times` WHERE id=?", 20).Scan(&t20)
+	require.NoError(t, err)
+
+	require.EqualValues(t, true, t20.Valid)
+	require.EqualValues(t, now.UTC(), t20.Time().UTC())
+
+	result, err = d.Exec("INSERT INTO `times`(`id`,`created_at`) VALUES(?, ?)", 11, t10)
+	require.NoError(t, err)
+
+	rows, err = result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(1), rows)
+
+	result, err = d.Exec("INSERT INTO `times`(`id`, `created_at`) VALUES(?, ?)", 21, t20)
+	require.NoError(t, err)
+
+	rows, err = result.RowsAffected()
+	require.NoError(t, err)
+	require.Equal(t, int64(1), rows)
+
+	var t11 Time
+	err = d.QueryRow("SELECT `created_at` FROM `times` WHERE id=?", 11).Scan(&t11)
+	require.NoError(t, err)
+
+	require.EqualValues(t, false, t11.Valid)
+
+	var t21 Time
+	err = d.QueryRow("SELECT `created_at` FROM `times` WHERE id=?", 21).Scan(&t21)
+	require.NoError(t, err)
+
+	require.EqualValues(t, true, t21.Valid)
+	require.EqualValues(t, now.UTC(), t21.Time().UTC())
+
+}
+
+func TestTimeInJSON(t *testing.T) {
+
+	sysTime := time.Now()
+
+	bufSysTime, err := json.Marshal(sysTime)
+	require.NoError(t, err)
+
+	sqleTime := NewTime(sysTime, true)
+
+	bufSqleTime, err := json.Marshal(sqleTime)
+	require.NoError(t, err)
+
+	require.Equal(t, bufSysTime, bufSqleTime)
+
+	var jsSqleTime Time
+	// Unmarshal sqle.Time from time.Time json bytes
+	err = json.Unmarshal(bufSysTime, &jsSqleTime)
+	require.NoError(t, err)
+
+	require.True(t, sysTime.Equal(jsSqleTime.Time()))
+	require.Equal(t, true, jsSqleTime.Valid)
+
+	var jsSysTime time.Time
+	// Unmarshal time.Time from sqle.Time json bytes
+	err = json.Unmarshal(bufSqleTime, &jsSysTime)
+	require.NoError(t, err)
+	require.True(t, sysTime.Equal(jsSysTime))
+}