From fda565702211fb0fcd01ae92009d840ce8d1acc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=80=97=E5=AD=90?= Date: Fri, 12 Jul 2024 20:10:40 +0800 Subject: [PATCH] fix: session driver lock problem (#558) * fix: session driver lock problem * fix: start_session test * fix: MakeSession test --- foundation/application_test.go | 2 ++ session/manager.go | 14 +++++++------- session/manager_test.go | 7 ++----- session/middleware/start_session_test.go | 15 +++++++-------- 4 files changed, 18 insertions(+), 20 deletions(-) diff --git a/foundation/application_test.go b/foundation/application_test.go index e5f530a88..406d87e05 100644 --- a/foundation/application_test.go +++ b/foundation/application_test.go @@ -377,6 +377,8 @@ func (s *ApplicationTestSuite) TestMakeSchedule() { func (s *ApplicationTestSuite) TestMakeSession() { mockConfig := &configmocks.Config{} + mockConfig.On("GetInt", "session.lifetime").Return(120).Once() + mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() s.app.Singleton(frameworkconfig.Binding, func(app foundation.Application) (any, error) { return mockConfig, nil diff --git a/session/manager.go b/session/manager.go index b2cc01d7d..ed3af0854 100644 --- a/session/manager.go +++ b/session/manager.go @@ -11,16 +11,16 @@ import ( type Manager struct { config config.Config - customDrivers map[string]func() sessioncontract.Driver - drivers map[string]func() sessioncontract.Driver + customDrivers map[string]sessioncontract.Driver + drivers map[string]sessioncontract.Driver json foundation.Json } func NewManager(config config.Config, json foundation.Json) *Manager { manager := &Manager{ config: config, - customDrivers: make(map[string]func() sessioncontract.Driver), - drivers: make(map[string]func() sessioncontract.Driver), + customDrivers: make(map[string]sessioncontract.Driver), + drivers: make(map[string]sessioncontract.Driver), json: json, } manager.registerDrivers() @@ -51,11 +51,11 @@ func (m *Manager) Driver(name ...string) (sessioncontract.Driver, error) { m.drivers[driverName] = m.customDrivers[driverName] } - return m.drivers[driverName](), nil + return m.drivers[driverName], nil } func (m *Manager) Extend(driver string, handler func() sessioncontract.Driver) sessioncontract.Manager { - m.customDrivers[driver] = handler + m.customDrivers[driver] = handler() return m } @@ -69,5 +69,5 @@ func (m *Manager) createFileDriver() sessioncontract.Driver { } func (m *Manager) registerDrivers() { - m.drivers["file"] = m.createFileDriver + m.drivers["file"] = m.createFileDriver() } diff --git a/session/manager_test.go b/session/manager_test.go index f3f7b71f4..5e5bda8cf 100644 --- a/session/manager_test.go +++ b/session/manager_test.go @@ -25,14 +25,13 @@ func TestManagerTestSuite(t *testing.T) { func (m *ManagerTestSuite) SetupTest() { m.mockConfig = mockconfig.NewConfig(m.T()) + m.mockConfig.On("GetInt", "session.lifetime").Return(120).Once() + m.mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() m.manager = m.getManager() m.json = json.NewJson() } func (m *ManagerTestSuite) TestDriver() { - m.mockConfig.On("GetInt", "session.lifetime").Return(120).Once() - m.mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() - // provide driver name driver, err := m.manager.Driver("file") m.Nil(err) @@ -41,8 +40,6 @@ func (m *ManagerTestSuite) TestDriver() { // provide no driver name m.mockConfig.On("GetString", "session.driver").Return("file").Once() - m.mockConfig.On("GetInt", "session.lifetime").Return(120).Once() - m.mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() driver, err = m.manager.Driver() m.Nil(err) diff --git a/session/middleware/start_session_test.go b/session/middleware/start_session_test.go index d559d5ed8..a697cdfe8 100644 --- a/session/middleware/start_session_test.go +++ b/session/middleware/start_session_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/require" "github.com/goravel/framework/contracts/filesystem" - "github.com/goravel/framework/contracts/foundation" contractshttp "github.com/goravel/framework/contracts/http" contractsession "github.com/goravel/framework/contracts/session" "github.com/goravel/framework/contracts/validation" @@ -22,10 +21,8 @@ import ( "github.com/goravel/framework/support/file" ) -func testHttpSessionMiddleware(next nethttp.Handler, mockConfig *configmocks.Config, json foundation.Json) nethttp.Handler { +func testHttpSessionMiddleware(next nethttp.Handler, mockConfig *configmocks.Config) nethttp.Handler { return nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) { - session.ConfigFacade = mockConfig - session.SessionFacade = session.NewManager(mockConfig, json) mockConfigFacade(mockConfig) StartSession()(NewTestContext(r.Context(), next, w, r)) }) @@ -33,8 +30,7 @@ func testHttpSessionMiddleware(next nethttp.Handler, mockConfig *configmocks.Con func mockConfigFacade(mockConfig *configmocks.Config) { mockConfig.On("GetString", "session.driver").Return("file").Twice() - mockConfig.On("GetInt", "session.lifetime").Return(60).Times(3) - mockConfig.On("GetString", "session.files").Return("sessions").Once() + mockConfig.On("GetInt", "session.lifetime").Return(60).Twice() mockConfig.On("GetString", "session.cookie").Return("goravel_session").Once() mockConfig.On("GetString", "session.path").Return("/").Once() mockConfig.On("GetString", "session.domain").Return("").Once() @@ -46,7 +42,10 @@ func mockConfigFacade(mockConfig *configmocks.Config) { func TestStartSession(t *testing.T) { mockConfig := &configmocks.Config{} - j := json.NewJson() + session.ConfigFacade = mockConfig + mockConfig.On("GetInt", "session.lifetime").Return(120).Once() + mockConfig.On("GetString", "session.files").Return("storage/framework/sessions").Once() + session.SessionFacade = session.NewManager(mockConfig, json.NewJson()) server := httptest.NewServer(testHttpSessionMiddleware(nethttp.HandlerFunc(func(w nethttp.ResponseWriter, r *nethttp.Request) { switch r.URL.Path { case "/add": @@ -59,7 +58,7 @@ func TestStartSession(t *testing.T) { assert.Equal(t, "bar", s.Get("foo")) assert.Equal(t, "qux", s.Get("baz")) } - }), mockConfig, j)) + }), mockConfig)) defer server.Close() client := &nethttp.Client{}