diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index e2752ad..8a1d626 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1,13 +1 @@ -# These are supported funding model platforms - -#github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] -#patreon: # Replace with a single Patreon username -#open_collective: # Replace with a single Open Collective username -#ko_fi: # Replace with a single Ko-fi username -#tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel -#community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry -#liberapay: markus621 -#issuehunt: # Replace with a single IssueHunt username -#otechie: # Replace with a single Otechie username -#lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry -custom: ['https://www.dewep.pro/donate'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] +#custom: ['https://osspkg.com/donate'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.golangci.yml b/.golangci.yml index 3e0ab98..77506ff 100755 --- a/.golangci.yml +++ b/.golangci.yml @@ -346,6 +346,12 @@ linters-settings: # Tab width in spaces. # Default: 1 tab-width: 1 + staticcheck: + # Deprecated: use the global `run.go` instead. + go: "1.15" + # SAxxxx checks in https://staticcheck.io/docs/configuration/options/#checks + # Default: ["*"] + checks: [ "*", "-SA1019" ] linters: disable-all: true @@ -362,7 +368,6 @@ linters: - unused - prealloc - durationcheck -# - nolintlint - staticcheck - makezero - nilerr diff --git a/examples/demo-ws-cli/main.go b/examples/demo-ws-cli/main.go deleted file mode 100644 index 4179b61..0000000 --- a/examples/demo-ws-cli/main.go +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. - * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. - */ - -package main - -import ( - "context" - "fmt" - "time" - - "github.com/osspkg/goppy" - "github.com/osspkg/goppy/plugins" - "github.com/osspkg/goppy/plugins/web" -) - -func main() { - app := goppy.New() - app.Plugins( - web.WithWebsocketClient(), - ) - app.Plugins( - plugins.Plugin{ - Inject: NewController, - Resolve: func(c *Controller, ws web.WebsocketClient) error { - wsc, err := ws.Create(context.TODO(), "ws://127.0.0.1:8088/ws") - if err != nil { - return err - } - - wsc.Event(c.EventListener, 99) - go c.Ticker(wsc.Encode) - - time.AfterFunc(30*time.Second, func() { - wsc.Close() - }) - - go wsc.Run() - - return nil - }, - }, - ) - app.Run() -} - -type Controller struct{} - -func NewController() *Controller { - return &Controller{} -} - -func (v *Controller) Ticker(call func(id uint, in interface{})) { - t := time.NewTicker(time.Second * 3) - defer t.Stop() - - for { - select { - case tt := <-t.C: - call(99, tt.Format(time.RFC3339)) - } - } -} - -func (v *Controller) EventListener(d web.WebsocketEventer, c web.WebsocketClientProcessor) error { - fmt.Println("EventListener", c.ConnectID(), d.UniqueID(), d.EventID()) - return nil -} diff --git a/examples/demo-ws/index.html b/examples/demo-ws/index.html deleted file mode 100644 index 18de050..0000000 --- a/examples/demo-ws/index.html +++ /dev/null @@ -1,67 +0,0 @@ - - - - -Close socket -
-
- -
- -
- -
-
- - -
-
- - \ No newline at end of file diff --git a/examples/demo-ws/main.go b/examples/demo-ws/main.go deleted file mode 100644 index e08553e..0000000 --- a/examples/demo-ws/main.go +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. - * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. - */ - -package main - -import ( - "fmt" - "sync" - "time" - - "github.com/osspkg/goppy" - "github.com/osspkg/goppy/plugins" - "github.com/osspkg/goppy/plugins/web" -) - -func main() { - app := goppy.New() - app.Plugins( - web.WithHTTP(), - web.WithWebsocketServer(), - ) - app.Plugins( - plugins.Plugin{ - Inject: NewController, - Resolve: func(routes web.RouterPool, c *Controller, ws web.WebsocketServer) { - router := routes.Main() - router.Use(web.ThrottlingMiddleware(100)) - - ws.Event(c.Event99, 99) - ws.Event(c.OneEvent, 1, 2) - ws.Event(c.MultiEvent, 11, 13) - - router.Get("/ws", ws.Handling) - }, - }, - ) - app.Run() -} - -type Controller struct { - list map[string]web.WebsocketServerProcessor - mux sync.RWMutex -} - -func NewController() *Controller { - c := &Controller{ - list: make(map[string]web.WebsocketServerProcessor), - } - go c.Timer() - return c -} - -func (v *Controller) Event99(ev web.WebsocketEventer, c web.WebsocketServerProcessor) error { - var data string - if err := ev.Decode(&data); err != nil { - return err - } - c.EncodeEvent(ev, &data) - fmt.Println(c.ConnectID(), "Event99", ev.EventID(), ev.UniqueID()) - return nil -} - -func (v *Controller) OneEvent(ev web.WebsocketEventer, c web.WebsocketServerProcessor) error { - list := make([]int, 0) - if err := ev.Decode(&list); err != nil { - return err - } - list = append(list, 10, 19, 17, 15) - c.EncodeEvent(ev, &list) - fmt.Println(c.ConnectID(), "OneEvent", ev.EventID(), ev.UniqueID()) - return nil -} - -func (v *Controller) Timer() { - t := time.NewTicker(time.Second * 3) - defer t.Stop() - - for { - select { - case tt := <-t.C: - v.muxRLock(func() { - for _, p := range v.list { - p.Encode(12, tt.Format(time.RFC3339)) - fmt.Println("Timer", p.ConnectID()) - } - }) - } - } -} - -func (v *Controller) MultiEvent(d web.WebsocketEventer, c web.WebsocketServerProcessor) error { - switch d.EventID() { - case 11: - v.muxLock(func() { - v.list[c.ConnectID()] = c - fmt.Println("MultiEvent Add", c.ConnectID()) - }) - - c.OnClose(func(cid string) { - v.muxLock(func() { - delete(v.list, cid) - fmt.Println("MultiEvent Close", cid) - }) - }) - - case 13: - v.muxLock(func() { - delete(v.list, c.ConnectID()) - fmt.Println("MultiEvent Del", c.ConnectID()) - }) - - } - return nil -} - -func (v *Controller) muxLock(cb func()) { - v.mux.Lock() - cb() - v.mux.Unlock() -} - -func (v *Controller) muxRLock(cb func()) { - v.mux.RLock() - cb() - v.mux.RUnlock() -} diff --git a/examples/demo-basic/config.yaml b/examples/goppy/demo-basic/config.yaml similarity index 100% rename from examples/demo-basic/config.yaml rename to examples/goppy/demo-basic/config.yaml diff --git a/examples/demo-basic/main.go b/examples/goppy/demo-basic/main.go similarity index 96% rename from examples/demo-basic/main.go rename to examples/goppy/demo-basic/main.go index f05174f..dc20814 100644 --- a/examples/demo-basic/main.go +++ b/examples/goppy/demo-basic/main.go @@ -9,10 +9,10 @@ import ( "fmt" "os" - "github.com/osspkg/go-sdk/console" "github.com/osspkg/goppy" "github.com/osspkg/goppy/plugins" "github.com/osspkg/goppy/plugins/web" + "github.com/osspkg/goppy/sdk/console" ) func main() { diff --git a/examples/demo-database/config.yaml b/examples/goppy/demo-database/config.yaml similarity index 100% rename from examples/demo-database/config.yaml rename to examples/goppy/demo-database/config.yaml diff --git a/examples/demo-database/main.go b/examples/goppy/demo-database/main.go similarity index 100% rename from examples/demo-database/main.go rename to examples/goppy/demo-database/main.go diff --git a/examples/demo-geoip/config.yaml b/examples/goppy/demo-geoip/config.yaml similarity index 100% rename from examples/demo-geoip/config.yaml rename to examples/goppy/demo-geoip/config.yaml diff --git a/examples/demo-geoip/main.go b/examples/goppy/demo-geoip/main.go similarity index 100% rename from examples/demo-geoip/main.go rename to examples/goppy/demo-geoip/main.go diff --git a/examples/demo-migrate-mysql/000_init.sql b/examples/goppy/demo-migrate-mysql/000_init.sql similarity index 100% rename from examples/demo-migrate-mysql/000_init.sql rename to examples/goppy/demo-migrate-mysql/000_init.sql diff --git a/examples/demo-migrate-mysql/001_demo2_table.sql b/examples/goppy/demo-migrate-mysql/001_demo2_table.sql similarity index 100% rename from examples/demo-migrate-mysql/001_demo2_table.sql rename to examples/goppy/demo-migrate-mysql/001_demo2_table.sql diff --git a/examples/demo-migrate-mysql/config.yaml b/examples/goppy/demo-migrate-mysql/config.yaml similarity index 100% rename from examples/demo-migrate-mysql/config.yaml rename to examples/goppy/demo-migrate-mysql/config.yaml diff --git a/examples/demo-migrate-mysql/main.go b/examples/goppy/demo-migrate-mysql/main.go similarity index 100% rename from examples/demo-migrate-mysql/main.go rename to examples/goppy/demo-migrate-mysql/main.go diff --git a/examples/demo-migrate-pgsql/000_init.sql b/examples/goppy/demo-migrate-pgsql/000_init.sql similarity index 100% rename from examples/demo-migrate-pgsql/000_init.sql rename to examples/goppy/demo-migrate-pgsql/000_init.sql diff --git a/examples/demo-migrate-pgsql/001_demo2_table.sql b/examples/goppy/demo-migrate-pgsql/001_demo2_table.sql similarity index 100% rename from examples/demo-migrate-pgsql/001_demo2_table.sql rename to examples/goppy/demo-migrate-pgsql/001_demo2_table.sql diff --git a/examples/demo-migrate-pgsql/config.yaml b/examples/goppy/demo-migrate-pgsql/config.yaml similarity index 100% rename from examples/demo-migrate-pgsql/config.yaml rename to examples/goppy/demo-migrate-pgsql/config.yaml diff --git a/examples/demo-migrate-pgsql/main.go b/examples/goppy/demo-migrate-pgsql/main.go similarity index 100% rename from examples/demo-migrate-pgsql/main.go rename to examples/goppy/demo-migrate-pgsql/main.go diff --git a/examples/demo-migrate-sqlite/000_init.sql b/examples/goppy/demo-migrate-sqlite/000_init.sql similarity index 100% rename from examples/demo-migrate-sqlite/000_init.sql rename to examples/goppy/demo-migrate-sqlite/000_init.sql diff --git a/examples/demo-migrate-sqlite/001_demo2_table.sql b/examples/goppy/demo-migrate-sqlite/001_demo2_table.sql similarity index 100% rename from examples/demo-migrate-sqlite/001_demo2_table.sql rename to examples/goppy/demo-migrate-sqlite/001_demo2_table.sql diff --git a/examples/demo-migrate-sqlite/config.yaml b/examples/goppy/demo-migrate-sqlite/config.yaml similarity index 100% rename from examples/demo-migrate-sqlite/config.yaml rename to examples/goppy/demo-migrate-sqlite/config.yaml diff --git a/examples/demo-migrate-sqlite/main.go b/examples/goppy/demo-migrate-sqlite/main.go similarity index 100% rename from examples/demo-migrate-sqlite/main.go rename to examples/goppy/demo-migrate-sqlite/main.go diff --git a/examples/demo-oauth/config.yaml b/examples/goppy/demo-oauth/config.yaml similarity index 100% rename from examples/demo-oauth/config.yaml rename to examples/goppy/demo-oauth/config.yaml diff --git a/examples/demo-oauth/main.go b/examples/goppy/demo-oauth/main.go similarity index 100% rename from examples/demo-oauth/main.go rename to examples/goppy/demo-oauth/main.go diff --git a/examples/demo-unix/config.yaml b/examples/goppy/demo-unix/config.yaml similarity index 100% rename from examples/demo-unix/config.yaml rename to examples/goppy/demo-unix/config.yaml diff --git a/examples/demo-unix/main.go b/examples/goppy/demo-unix/main.go similarity index 89% rename from examples/demo-unix/main.go rename to examples/goppy/demo-unix/main.go index 030ed18..5431b1c 100644 --- a/examples/demo-unix/main.go +++ b/examples/goppy/demo-unix/main.go @@ -22,7 +22,7 @@ func main() { ) app.Plugins( plugins.Plugin{ - Resolve: func(s unix.Server, c unix.Client) error { + Resolve: func(s unix.Server, c unix.Client, conf *unix.Config) error { s.Command("demo", func(bytes []byte) ([]byte, error) { fmt.Println("<", string(bytes)) @@ -30,7 +30,7 @@ func main() { }) time.AfterFunc(time.Second*5, func() { - cc, err := c.Create("/tmp/demo-unix.sock") + cc, err := c.Create(conf.Path) if err != nil { panic(err) } @@ -49,7 +49,7 @@ func main() { }) time.AfterFunc(time.Second*15, func() { - cc, err := c.Create("/tmp/demo-unix.sock") + cc, err := c.Create(conf.Path) if err != nil { panic(err) } diff --git a/examples/demo-ws-cli/config.yaml b/examples/goppy/demo-ws-cli/config.yaml similarity index 100% rename from examples/demo-ws-cli/config.yaml rename to examples/goppy/demo-ws-cli/config.yaml diff --git a/examples/goppy/demo-ws-cli/main.go b/examples/goppy/demo-ws-cli/main.go new file mode 100644 index 0000000..d4a3b90 --- /dev/null +++ b/examples/goppy/demo-ws-cli/main.go @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "encoding/json" + "fmt" + "time" + + "github.com/osspkg/goppy" + "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/plugins/web" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/netutil/websocket" +) + +func main() { + application := goppy.New() + application.Plugins( + web.WithWebsocketClient(), + ) + application.Plugins( + plugins.Plugin{ + Inject: func() *Controller { + return &Controller{} + }, + Resolve: func(c *Controller, ctx app.Context, ws web.WebsocketClient) { + wsc := ws.Create("ws://127.0.0.1:8088/ws") + wsc.SetHandler(c.EventListener, 99, 1, 65000) + go c.Ticker(wsc.Encode) + wsc.OnClose(func(cid string) { + fmt.Println("server close connect") + ctx.Close() + }) + }, + }, + ) + application.Run() +} + +type Controller struct{} + +func NewController() *Controller { + return &Controller{} +} + +func (v *Controller) Ticker(call func(id websocket.EventID, in interface{})) { + t := time.NewTicker(time.Second * 3) + defer t.Stop() + + for { + select { + case <-t.C: + call(1, []int{0}) + } + } +} + +func (v *Controller) EventListener(w websocket.CRequest, r websocket.CResponse, m websocket.CMeta) { + var vv json.RawMessage + if err := r.Decode(&vv); err != nil { + fmt.Println(err) + } + fmt.Println("EventListener", m.ConnectID(), r.EventID(), string(vv)) +} diff --git a/examples/demo-ws/config.yaml b/examples/goppy/demo-ws/config.yaml similarity index 100% rename from examples/demo-ws/config.yaml rename to examples/goppy/demo-ws/config.yaml diff --git a/examples/goppy/demo-ws/index.html b/examples/goppy/demo-ws/index.html new file mode 100644 index 0000000..218798a --- /dev/null +++ b/examples/goppy/demo-ws/index.html @@ -0,0 +1,71 @@ + + + + + + WS + + + + + +Close socket +
+
+ +
+ +
+ +
+
+ + +
+
+ + \ No newline at end of file diff --git a/examples/goppy/demo-ws/main.go b/examples/goppy/demo-ws/main.go new file mode 100644 index 0000000..7045fe6 --- /dev/null +++ b/examples/goppy/demo-ws/main.go @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + "sync" + "time" + + "github.com/osspkg/goppy/sdk/netutil/websocket" + + "github.com/osspkg/goppy" + "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/plugins/web" +) + +func main() { + app := goppy.New() + app.Plugins( + web.WithHTTP(), + web.WithWebsocketServer(), + ) + app.Plugins( + plugins.Plugin{ + Inject: func(ws web.WebsocketServer) *Controller { + return NewController(ws) + }, + Resolve: func(routes web.RouterPool, c *Controller, ws web.WebsocketServer) { + router := routes.Main() + router.Use(web.ThrottlingMiddleware(100)) + + ws.SetHandler(c.Event99, 99) + ws.SetHandler(c.OneEvent, 1, 2) + ws.SetHandler(c.MultiEvent, 11, 13) + + router.Get("/ws", func(ctx web.Context) { + ws.Handling(ctx.Response(), ctx.Request()) + }) + }, + }, + ) + app.Run() +} + +type ( + sender interface { + SendEvent(eid websocket.EventID, m interface{}, cids ...string) + Broadcast(eid websocket.EventID, m interface{}) + } + Controller struct { + list map[string]struct{} + sender sender + mux sync.RWMutex + } +) + +func NewController(s sender) *Controller { + c := &Controller{ + list: make(map[string]struct{}), + sender: s, + } + go c.Timer() + return c +} + +func (v *Controller) Event99(w websocket.Response, r websocket.Request, m websocket.Meta) error { + var data string + if err := r.Decode(&data); err != nil { + return err + } + w.Encode(&data) + fmt.Println(m.ConnectID(), "Event99", r.EventID()) + return nil +} + +func (v *Controller) OneEvent(w websocket.Response, r websocket.Request, m websocket.Meta) error { + list := make([]int, 0) + if err := r.Decode(&list); err != nil { + return err + } + list = append(list, 10, 19, 17, 15) + w.Encode(&list) + fmt.Println(m.ConnectID(), "OneEvent", r.EventID()) + return nil +} + +func (v *Controller) Timer() { + t := time.NewTicker(time.Second * 3) + defer t.Stop() + + for { + select { + case tt := <-t.C: + v.muxRLock(func() { + for cid := range v.list { + v.sender.SendEvent(12, tt.Format(time.RFC3339), cid) + fmt.Println("Timer", cid) + } + }) + v.sender.Broadcast(99, tt.Unix()) + } + } +} + +func (v *Controller) MultiEvent(w websocket.Response, r websocket.Request, m websocket.Meta) error { + switch r.EventID() { + case 11: + v.muxLock(func() { + v.list[m.ConnectID()] = struct{}{} + fmt.Println("MultiEvent Add", m.ConnectID()) + }) + + m.OnClose(func(cid string) { + v.muxLock(func() { + delete(v.list, cid) + fmt.Println("MultiEvent Close", cid) + }) + }) + + case 13: + v.muxLock(func() { + delete(v.list, m.ConnectID()) + fmt.Println("MultiEvent Del", m.ConnectID()) + }) + + } + return nil +} + +func (v *Controller) muxLock(cb func()) { + v.mux.Lock() + cb() + v.mux.Unlock() +} + +func (v *Controller) muxRLock(cb func()) { + v.mux.RLock() + cb() + v.mux.RUnlock() +} diff --git a/examples/sdk/demo-app1/config.yaml b/examples/sdk/demo-app1/config.yaml new file mode 100644 index 0000000..6488b96 --- /dev/null +++ b/examples/sdk/demo-app1/config.yaml @@ -0,0 +1,4 @@ +env: dev +level: 4 +log: /dev/stdout +pig: /tmp/simple.pid diff --git a/examples/sdk/demo-app1/main.go b/examples/sdk/demo-app1/main.go new file mode 100644 index 0000000..8405571 --- /dev/null +++ b/examples/sdk/demo-app1/main.go @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" +) + +type ( + //Simple model + Simple struct{} + //Config model + Config1 struct { + Env string `yaml:"env"` + } + Config2 struct { + Env string `yaml:"env"` + } +) + +// NewSimple init Simple +func NewSimple(c1 Config1, c2 Config2) *Simple { + fmt.Println("--> call NewSimple") + fmt.Println("--> Config1.ENV=" + c1.Env) + fmt.Println("--> Config2.ENV=" + c2.Env) + return &Simple{} +} + +// Up method for start Simple in DI container +func (s *Simple) Up(_ app.Context) error { + fmt.Println("--> call *Simple.Up") + return nil +} + +// Down method for stop Simple in DI container +func (s *Simple) Down(_ app.Context) error { + fmt.Println("--> call *Simple.Down") + return nil +} + +func main() { + app.New(). + Logger(log.Default()). + ConfigFile( + "./config.yaml", + Config1{}, + ). + Modules( + Config2{Env: "prod"}, + NewSimple, + ). + Run() +} diff --git a/examples/sdk/demo-app2/config.yaml b/examples/sdk/demo-app2/config.yaml new file mode 100644 index 0000000..27ff22b --- /dev/null +++ b/examples/sdk/demo-app2/config.yaml @@ -0,0 +1,7 @@ +env: dev +level: 4 +log: /dev/stdout +pig: /tmp/simple.pid + +http: + aaa: 000 \ No newline at end of file diff --git a/examples/sdk/demo-app2/main.go b/examples/sdk/demo-app2/main.go new file mode 100644 index 0000000..33e14a1 --- /dev/null +++ b/examples/sdk/demo-app2/main.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" +) + +type ( + Test0 struct{} + Test1 struct{} + Test2 struct{} + + Config struct { + Env string `yaml:"env"` + Level string `yaml:"level"` + } + + Params struct { + Test1 *Test1 + Config Config + } +) + +func (s *Test2) Up() error { + fmt.Println("--> call *Test2.Up") + return nil +} + +func (s *Test2) Down() error { + fmt.Println("--> call *Test2.Down") + return nil +} + +func NewTest0(p Params) *Test0 { + fmt.Println("--> call NewTest0") + fmt.Println("--> Params.Config.Env=" + p.Config.Env) + return &Test0{} +} + +func NewTest2(_ *Test0) *Test2 { + fmt.Println("--> call NewTest2") + return &Test2{} +} + +func main() { + app.New(). + Logger(log.Default()). + ConfigFile( + "./config.yaml", + Config{}, + ). + Modules( + &Test1{}, + NewTest0, + NewTest2, + ). + Run() +} diff --git a/examples/sdk/demo-cli-app1/main.go b/examples/sdk/demo-cli-app1/main.go new file mode 100644 index 0000000..3d4fdbd --- /dev/null +++ b/examples/sdk/demo-cli-app1/main.go @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + "strings" + + "github.com/osspkg/goppy/sdk/console" +) + +func main() { + root := console.New("tool", "help tool") + + simpleCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("simple", "third level") + setter.Example("simple aa/bb/cc -a=hello -b=123 --cc=123.456 -e") + + setter.Flag(func(f console.FlagsSetter) { + f.StringVar("a", "demo", "this is a string argument") + f.IntVar("b", 1, "this is a int64 argument") + f.FloatVar("cc", 1e-5, "this is a float64 argument") + f.Bool("e", "this is a bool argument") + }) + + setter.ArgumentFunc(func(s []string) ([]string, error) { + if !strings.Contains(s[0], "/") { + return nil, fmt.Errorf("argument must contain /") + } + return strings.Split(s[0], "/"), nil + }) + + setter.ExecFunc(func(args []string, a string, b int64, c float64, d bool) { + fmt.Println(args, a, b, c, d) + }) + }) + + twoCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("two", "second level") + + setter.AddCommand(simpleCmd) + }) + + oneCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("one", "first level") + + setter.AddCommand(twoCmd) + }) + + root.AddCommand(oneCmd) + root.Exec() +} diff --git a/examples/sdk/demo-cli-app2/main.go b/examples/sdk/demo-cli-app2/main.go new file mode 100644 index 0000000..aa4e2e6 --- /dev/null +++ b/examples/sdk/demo-cli-app2/main.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + "strings" + + "github.com/osspkg/goppy/sdk/console" +) + +func main() { + root := console.New("tool", "help tool") + + cmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("simple", "first-level command") + setter.Example("simple aa/bb/cc -a=hello -b=123 --cc=123.456 -e") + + setter.Flag(func(f console.FlagsSetter) { + f.StringVar("a", "demo", "this is a string argument") + f.IntVar("b", 1, "this is a int64 argument") + f.FloatVar("cc", 1e-5, "this is a float64 argument") + f.Bool("e", "this is a bool argument") + }) + + setter.ArgumentFunc(func(s []string) ([]string, error) { + if !strings.Contains(s[0], "/") { + return nil, fmt.Errorf("argument must contain `/`") + } + return strings.Split(s[0], "/"), nil + }) + + setter.ExecFunc(func(args []string, a string, b int64, c float64, d bool) { + fmt.Println(args, a, b, c, d) + }) + }) + + root.AddCommand(cmd) + root.Exec() +} diff --git a/examples/sdk/demo-cli-app3/main.go b/examples/sdk/demo-cli-app3/main.go new file mode 100644 index 0000000..e7df3b4 --- /dev/null +++ b/examples/sdk/demo-cli-app3/main.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + "strings" + + "github.com/osspkg/goppy/sdk/console" +) + +func main() { + console.ShowDebug(true) + + app := console.New("tool", "help tool") + + cmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("a", "command a") + setter.ExecFunc(func(args []string) { + fmt.Println("a", args) + }) + + setter.AddCommand(console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("b", "command b") + setter.ExecFunc(func(args []string) { + fmt.Println("b", args) + }) + })) + + }) + + root := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("root", "command root") + setter.Flag(func(setter console.FlagsSetter) { + setter.Bool("aaa", "bool a") + }) + setter.ArgumentFunc(func(s []string) ([]string, error) { + return []string{strings.Join(s, "-")}, nil + }) + setter.ExecFunc(func(args []string, a bool) { + fmt.Println("root", args, a) + }) + }) + + app.RootCommand(root) + app.AddCommand(cmd) + app.Exec() +} diff --git a/examples/sdk/demo-oauth1/main.go b/examples/sdk/demo-oauth1/main.go new file mode 100644 index 0000000..aedeaab --- /dev/null +++ b/examples/sdk/demo-oauth1/main.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "fmt" + "net/http" + "time" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/auth/oauth" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" +) + +var ( + provConf = &oauth.Config{ + Provider: []oauth.ConfigItem{ + { + Code: "google", + ClientID: "****************.apps.googleusercontent.com", + ClientSecret: "****************", + RedirectURL: "https://example.com/oauth/callback/google", + }, + }, + } + + servConf = webutil.ConfigHttp{Addr: ":8080"} +) + +func main() { + ctx := app.NewContext() + authServ := oauth.New(provConf) + + route := webutil.NewRouter() + route.Route("/oauth/request/google", authServ.Request(oauth.CodeGoogle), http.MethodGet) + route.Route("/oauth/callback/google", authServ.CallBack(oauth.CodeGoogle, oauthCallBackHandler), http.MethodGet) + + serv := webutil.NewServerHttp(servConf, route, log.Default()) + serv.Up(ctx) //nolint: errcheck + <-time.After(60 * time.Minute) + ctx.Close() + serv.Down() //nolint: errcheck +} + +const out = ` +email: %s +name: %s +ico: %s +` + +func oauthCallBackHandler(w http.ResponseWriter, _ *http.Request, u oauth.User) { + w.WriteHeader(200) + fmt.Fprintf(w, out, u.GetEmail(), u.GetName(), u.GetIcon()) +} diff --git a/examples/sdk/demo-ws-cli-observable/main.go b/examples/sdk/demo-ws-cli-observable/main.go new file mode 100644 index 0000000..3b48ca0 --- /dev/null +++ b/examples/sdk/demo-ws-cli-observable/main.go @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package main + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/osspkg/goppy/sdk/syscall" + + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + + "github.com/osspkg/goppy/sdk/netutil/websocket" +) + +func main() { + group := iosync.NewGroup() + ctx, cncl := context.WithCancel(context.TODO()) + cli := websocket.NewClient(ctx, "ws://127.0.0.1:8088/ws", log.Default()) + defer cli.Close() + go syscall.OnStop(func() { + cli.Close() + }) + group.Background(func() { + err := cli.DialAndListen() + if err != nil { + log.WithError("err", err).Errorf("ws dial") + } + }) + <-time.After(100 * time.Millisecond) + + obs := websocket.NewObservable(cli) + + obs.Subscribe(1, []int{0}). + Listen(func(arg websocket.ListenArg) { + var vv json.RawMessage + if err := arg.Decode(&vv); err != nil { + fmt.Println(err) + } + fmt.Println(string(vv)) + }, + websocket.PipeTake(1), + websocket.PipeTimeout(1*time.Second), + ) + + obs.Subscribe(99, nil). + Listen(func(arg websocket.ListenArg) { + var vv json.RawMessage + if err := arg.Decode(&vv); err != nil { + fmt.Println(err) + } + fmt.Println(string(vv)) + }, + websocket.PipeTake(3), + ) + + cncl() + group.Wait() +} diff --git a/go.mod b/go.mod index f96e88f..583b252 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,18 @@ module github.com/osspkg/goppy go 1.18 require ( + github.com/go-sql-driver/mysql v1.7.1 github.com/gorilla/websocket v1.5.0 + github.com/lib/pq v1.10.9 github.com/mailru/easyjson v0.7.7 + github.com/mattn/go-sqlite3 v1.14.17 github.com/oschwald/geoip2-golang v1.9.0 - github.com/osspkg/go-sdk v1.3.6 - github.com/osspkg/go-static v1.3.2 + github.com/osspkg/go-algorithms v1.2.6 + github.com/osspkg/go-static v1.3.3 github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.14.0 + golang.org/x/oauth2 v0.13.0 + golang.org/x/sys v0.13.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -16,17 +22,13 @@ require ( cloud.google.com/go/compute v1.20.1 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.7.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/josharian/intern v1.0.0 // indirect - github.com/lib/pq v1.10.9 // indirect - github.com/mattn/go-sqlite3 v1.14.17 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/oschwald/maxminddb-golang v1.11.0 // indirect - github.com/osspkg/go-algorithms v1.2.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/net v0.12.0 // indirect - golang.org/x/oauth2 v0.10.0 // indirect - golang.org/x/sys v0.10.0 // indirect + golang.org/x/net v0.17.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.31.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/go.sum b/go.sum index d75732d..f37770a 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,7 @@ cloud.google.com/go/compute v1.20.1 h1:6aKEtlUiwEpJzM001l0yFkpXmUVXaN8W+fbkb2AZN cloud.google.com/go/compute v1.20.1/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= @@ -16,6 +17,10 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -28,23 +33,26 @@ github.com/oschwald/maxminddb-golang v1.11.0 h1:aSXMqYR/EPNjGE8epgqwDay+P30hCBZI github.com/oschwald/maxminddb-golang v1.11.0/go.mod h1:YmVI+H0zh3ySFR3w+oz8PCfglAFj3PuCmui13+P9zDg= github.com/osspkg/go-algorithms v1.2.6 h1:/eIZ1XlxZ2LRxtbOG4REA5KcR17F5E64+Y2cudJvfRU= github.com/osspkg/go-algorithms v1.2.6/go.mod h1:Zdclm/CKhDrUD34kIm9PL4VDbiey/jXAHY36nfc0r5Q= -github.com/osspkg/go-sdk v1.3.6 h1:VcB5o3+c1uwkD2B0v4ZxOeaN/SN0QUr3Fktt8S3Le0M= -github.com/osspkg/go-sdk v1.3.6/go.mod h1:/ZzPlvttlMyTLiDomDGJt7sqFz/RmkWXzrK5tifci14= -github.com/osspkg/go-static v1.3.2 h1:MST9eSG/gO6b1OjFlxoFfSLhdKRiaUA8lEUhh0KoUOg= -github.com/osspkg/go-static v1.3.2/go.mod h1:PI0nuemvgmLuVgn3hIZMRFgQH6d1Iyn70Gs23Gl05qo= +github.com/osspkg/go-static v1.3.3 h1:jqwE6zoucvgAAuh7MJ2spj2INhZeotq+LRlQzg1B5sg= +github.com/osspkg/go-static v1.3.3/go.mod h1:PI0nuemvgmLuVgn3hIZMRFgQH6d1Iyn70Gs23Gl05qo= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= -golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= -golang.org/x/oauth2 v0.10.0 h1:zHCpF2Khkwy4mMB4bv0U37YtJdTGW8jI0glAApi0Kh8= -golang.org/x/oauth2 v0.10.0/go.mod h1:kTpgurOux7LqtuxjuyZa4Gj2gdezIt/jQtGnNFfypQI= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/oauth2 v0.13.0 h1:jDDenyj+WgFtmV3zYVoi8aE2BwtXFLWOA67ZfNWftiY= +golang.org/x/oauth2 v0.13.0/go.mod h1:/JMhi4ZRXAf4HG9LiNmxvk+45+96RUlVThiH8FzNBn0= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= -golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -55,7 +63,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/goppy.go b/goppy.go index 2bcfe18..40dadcf 100644 --- a/goppy.go +++ b/goppy.go @@ -10,10 +10,10 @@ import ( "os" "reflect" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/console" - "github.com/osspkg/go-sdk/errors" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/console" + "github.com/osspkg/goppy/sdk/errors" "gopkg.in/yaml.v3" ) diff --git a/plugins/auth/jwt.go b/plugins/auth/jwt.go index 87f5bff..cb81548 100644 --- a/plugins/auth/jwt.go +++ b/plugins/auth/jwt.go @@ -9,9 +9,9 @@ import ( "fmt" "time" - "github.com/osspkg/go-sdk/auth/jwt" - "github.com/osspkg/go-sdk/random" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/auth/jwt" + "github.com/osspkg/goppy/sdk/random" ) type ( diff --git a/plugins/auth/middleware.go b/plugins/auth/middleware.go index adbf054..b971b0e 100644 --- a/plugins/auth/middleware.go +++ b/plugins/auth/middleware.go @@ -12,8 +12,8 @@ import ( "net/http" "strings" - "github.com/osspkg/go-sdk/auth/jwt" "github.com/osspkg/goppy/plugins/web" + "github.com/osspkg/goppy/sdk/auth/jwt" ) const ( diff --git a/plugins/auth/oauth.go b/plugins/auth/oauth.go index 5e94c54..ca86ed6 100644 --- a/plugins/auth/oauth.go +++ b/plugins/auth/oauth.go @@ -8,9 +8,9 @@ package auth import ( "net/http" - "github.com/osspkg/go-sdk/auth/oauth" "github.com/osspkg/goppy/plugins" "github.com/osspkg/goppy/plugins/web" + "github.com/osspkg/goppy/sdk/auth/oauth" ) // ConfigOAuth oauth config model diff --git a/plugins/database/migrator.go b/plugins/database/migrator.go index 01ce0de..d87a64a 100644 --- a/plugins/database/migrator.go +++ b/plugins/database/migrator.go @@ -13,12 +13,12 @@ import ( "sort" "time" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/iofile" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/orm" - "github.com/osspkg/go-sdk/orm/schema" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iofile" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/orm" + "github.com/osspkg/goppy/sdk/orm/schema" ) type ( diff --git a/plugins/database/mysql.go b/plugins/database/mysql.go index 57123c3..67db7b3 100644 --- a/plugins/database/mysql.go +++ b/plugins/database/mysql.go @@ -11,14 +11,14 @@ import ( "sync" "time" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/orm" - "github.com/osspkg/go-sdk/orm/schema" - "github.com/osspkg/go-sdk/orm/schema/mysql" - "github.com/osspkg/go-sdk/routine" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/orm" + "github.com/osspkg/goppy/sdk/orm/schema" + "github.com/osspkg/goppy/sdk/orm/schema/mysql" + "github.com/osspkg/goppy/sdk/routine" ) // ConfigMysql mysql config model diff --git a/plugins/database/pgsql.go b/plugins/database/pgsql.go index 30c12d7..d7d15ba 100644 --- a/plugins/database/pgsql.go +++ b/plugins/database/pgsql.go @@ -11,14 +11,14 @@ import ( "sync" "time" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/orm" - "github.com/osspkg/go-sdk/orm/schema" - "github.com/osspkg/go-sdk/orm/schema/postgresql" - "github.com/osspkg/go-sdk/routine" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/orm" + "github.com/osspkg/goppy/sdk/orm/schema" + "github.com/osspkg/goppy/sdk/orm/schema/postgresql" + "github.com/osspkg/goppy/sdk/routine" ) // ConfigPgsql pgsql config model diff --git a/plugins/database/sqlite.go b/plugins/database/sqlite.go index 9893075..3f4b455 100644 --- a/plugins/database/sqlite.go +++ b/plugins/database/sqlite.go @@ -11,15 +11,15 @@ import ( "sync" "time" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/iofile" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/orm" - "github.com/osspkg/go-sdk/orm/schema" - "github.com/osspkg/go-sdk/orm/schema/sqlite" - "github.com/osspkg/go-sdk/routine" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iofile" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/orm" + "github.com/osspkg/goppy/sdk/orm/schema" + "github.com/osspkg/goppy/sdk/orm/schema/sqlite" + "github.com/osspkg/goppy/sdk/routine" ) // ConfigSqlite sqlite config model diff --git a/plugins/unix/client.go b/plugins/unix/client.go index 515e041..4df8375 100644 --- a/plugins/unix/client.go +++ b/plugins/unix/client.go @@ -6,24 +6,23 @@ package unix import ( - "net" "sync" - "github.com/osspkg/go-sdk/errors" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/netutil/unixsocket" ) func WithClient() plugins.Plugin { return plugins.Plugin{ - Inject: func() (*cliProvider, Client) { - s := newCliProvider() + Inject: func() (*clientProvider, Client) { + s := newClientProvider() return s, s }, } } type ( - cliProvider struct { + clientProvider struct { list map[string]ClientConnect mux sync.RWMutex } @@ -31,56 +30,26 @@ type ( Client interface { Create(path string) (ClientConnect, error) } + + ClientConnect interface { + Exec(name string, b []byte) ([]byte, error) + ExecString(name string, b string) ([]byte, error) + } ) -func newCliProvider() *cliProvider { - return &cliProvider{ +func newClientProvider() *clientProvider { + return &clientProvider{ list: make(map[string]ClientConnect), } } -func (v *cliProvider) Create(path string) (ClientConnect, error) { +func (v *clientProvider) Create(path string) (ClientConnect, error) { v.mux.Lock() defer v.mux.Unlock() if c, ok := v.list[path]; ok { return c, nil } - c := newClient(path) + c := unixsocket.NewClient(path) v.list[path] = c return c, nil } - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -type ( - cli struct { - path string - } - - ClientConnect interface { - Exec(name string, b []byte) ([]byte, error) - ExecString(name string, b string) ([]byte, error) - } -) - -func newClient(path string) *cli { - return &cli{ - path: path, - } -} - -func (v *cli) Exec(name string, b []byte) ([]byte, error) { - conn, err := net.Dial("unix", v.path) - if err != nil { - return nil, errors.Wrapf(err, "open connect [unix:%s]", v.path) - } - defer conn.Close() //nolint: errcheck - if err = writeBytes(conn, append([]byte(name+cmddelimstring), b...)); err != nil { - return nil, err - } - return readBytes(conn) -} - -func (v *cli) ExecString(name string, b string) ([]byte, error) { - return v.Exec(name, []byte(b)) -} diff --git a/plugins/unix/common.go b/plugins/unix/common.go deleted file mode 100644 index d00e0cb..0000000 --- a/plugins/unix/common.go +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. - * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. - */ - -package unix - -import ( - "bytes" - "io" - - "github.com/osspkg/go-sdk/errors" -) - -var ( - delimstring = "\n" - delimbyte = []byte(delimstring) - delimlen = len(delimbyte) - - cmddelimstring = " " - cmddelim = byte(' ') - - errInvalidCommand = errors.New("command not found") -) - -func readBytes(v io.Reader) ([]byte, error) { - var ( - n int - err error - b = make([]byte, 0, 512) - ) - - for { - if len(b) == cap(b) { - b = append(b, 0)[:len(b)] - } - n, err = v.Read(b[len(b):cap(b)]) - b = b[:len(b)+n] - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - if len(b) < delimlen { - return b, io.EOF - } - if bytes.Equal(delimbyte, b[len(b)-delimlen:]) { - b = b[:len(b)-delimlen] - break - } - } - return b, nil -} - -func writeBytes(v io.Writer, b []byte) error { - if len(b) < delimlen || !bytes.Equal(delimbyte, b[len(b)-delimlen:]) { - b = append(b, delimbyte...) - } - if _, err := v.Write(b); err != nil { - return err - } - return nil -} - -func writeError(v io.Writer, err error) error { - return writeBytes(v, []byte(err.Error())) -} - -func parse(b []byte) (string, []byte) { - for i := 0; i < len(b); i++ { - if b[i] == cmddelim { - if len(b) > i+2 { - return string(b[0:i]), b[i+1:] - } - return string(b[0:i]), nil - } - } - return string(b), nil -} diff --git a/plugins/unix/server.go b/plugins/unix/server.go index ac03af0..712525a 100644 --- a/plugins/unix/server.go +++ b/plugins/unix/server.go @@ -6,15 +6,11 @@ package unix import ( - "io" - "net" - "os" - "sync" - "time" - - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil/unixsocket" ) type ( @@ -30,118 +26,60 @@ func (v *Config) Default() { func WithServer() plugins.Plugin { return plugins.Plugin{ Config: &Config{}, - Inject: func(c *Config, l log.Logger) (*srv, Server) { - s := newServer(c, l) + Inject: func(c *Config, l log.Logger) (*serverProvider, Server) { + s := newServerProvider(c, l) return s, s }, } } type ( - srv struct { - config *Config - sock net.Listener - log log.Logger - commands map[string]Handler - mux sync.RWMutex + serverProvider struct { + config *Config + serv *unixsocket.Server + wg iosync.Group + log log.Logger } - //Handler unix socket command handler - Handler func([]byte) ([]byte, error) - Server interface { - Command(name string, h Handler) + Command(name string, h func([]byte) ([]byte, error)) } ) -func newServer(c *Config, l log.Logger) *srv { - return &srv{ - config: c, - log: l, - commands: make(map[string]Handler), - } -} - -func (v *srv) Up() (err error) { - if err = os.Remove(v.config.Path); err != nil && !os.IsNotExist(err) { - err = errors.Wrapf(err, "remove unix socket [unix:%s]", v.config.Path) - return - } - if v.sock, err = net.Listen("unix", v.config.Path); err != nil { - err = errors.Wrapf(err, "init unix socket [unix:%s]", v.config.Path) - return +func newServerProvider(c *Config, l log.Logger) *serverProvider { + return &serverProvider{ + config: c, + log: l, + serv: unixsocket.NewServer(c.Path), } - - go v.accept() - return } -func (v *srv) Down() error { - if v.sock != nil { - return v.sock.Close() - } - return nil -} - -func (v *srv) Command(name string, h Handler) { - v.mux.Lock() - v.commands[name] = h - v.mux.Unlock() -} - -func (v *srv) logError(err error, msg string) { - if err == nil { - return - } - - v.log.WithFields(log.Fields{ - "err": err.Error(), - }).Errorf(msg) -} - -func (v *srv) accept() { - for { - fd, err := v.sock.Accept() - if err != nil { - v.logError(err, "accept unix socket") - return - } - if err = fd.SetDeadline(time.Now().Add(time.Hour)); err != nil { - v.logError(err, "unix socket set deadline") +func (v *serverProvider) Up(ctx app.Context) (err error) { + v.serv.ErrorLog(func(err error) { + v.log.WithError("err", err).Errorf("unix") + }) + v.wg.Background(func() { + if err := v.serv.Up(); err != nil { + v.log.WithFields(log.Fields{ + "err": err.Error(), + "path": v.config.Path, + }).Errorf("Unix server stopped") + ctx.Close() return } - go v.pump(fd) - } + v.log.WithFields(log.Fields{ + "path": v.config.Path, + }).Infof("Unix server stopped") + }) + return } -func (v *srv) pump(rw io.ReadWriteCloser) { - defer func() { - if err := rw.Close(); err != nil { - v.logError(err, "close unix socket request") - } - }() - - b, err := readBytes(rw) - if err != nil { - v.logError(err, "read unix socket request") - v.logError(writeError(rw, err), "write unix socket error") - return - } - - cmd, data := parse(b) - - v.mux.RLock() - h, ok := v.commands[cmd] - v.mux.RUnlock() - if !ok { - v.logError(writeError(rw, errInvalidCommand), "write unix socket error") - return - } +func (v *serverProvider) Down() error { + err := v.serv.Down() + v.wg.Wait() + return err +} - out, err := h(data) - if err != nil { - v.logError(writeError(rw, err), "write unix socket error") - return - } - v.logError(writeBytes(rw, out), "write unix socket response") +func (v *serverProvider) Command(name string, h func([]byte) ([]byte, error)) { + v.serv.AddCommand(name, h) } diff --git a/plugins/web/common.go b/plugins/web/common.go index c4ab400..e9ba2c6 100644 --- a/plugins/web/common.go +++ b/plugins/web/common.go @@ -4,21 +4,3 @@ */ package web - -type rwlocker interface { - RLock() - RUnlock() - Lock() - Unlock() -} - -func lock(l rwlocker, call func()) { - l.Lock() - call() - l.Unlock() -} -func rwlock(l rwlocker, call func()) { - l.RLock() - call() - l.RUnlock() -} diff --git a/plugins/web/debug_server.go b/plugins/web/debug_server.go index 2b9bf19..82666ac 100644 --- a/plugins/web/debug_server.go +++ b/plugins/web/debug_server.go @@ -6,9 +6,9 @@ package web import ( - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/webutil" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" ) // ConfigDebug config to initialize HTTP debug service diff --git a/plugins/web/http_client.go b/plugins/web/http_client.go index aac0cd2..974e13a 100644 --- a/plugins/web/http_client.go +++ b/plugins/web/http_client.go @@ -6,8 +6,8 @@ package web import ( - "github.com/osspkg/go-sdk/webutil" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/webutil" ) // WithHTTPClient init pool http clients diff --git a/plugins/web/http_server.go b/plugins/web/http_server.go index 95213fe..ce33b38 100644 --- a/plugins/web/http_server.go +++ b/plugins/web/http_server.go @@ -6,9 +6,9 @@ package web import ( - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/webutil" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" ) // ConfigHttp config to initialize HTTP service diff --git a/plugins/web/http_server_context.go b/plugins/web/http_server_context.go index 5ce3e1d..3b07ac0 100644 --- a/plugins/web/http_server_context.go +++ b/plugins/web/http_server_context.go @@ -15,10 +15,10 @@ import ( "net/url" "strconv" - "github.com/osspkg/go-sdk/ioutil" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/webutil" "github.com/osspkg/go-static" + "github.com/osspkg/goppy/sdk/ioutil" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" ) type ( diff --git a/plugins/web/http_server_context_easyjson.go b/plugins/web/http_server_context_easyjson.go index 835d964..ce70c81 100644 --- a/plugins/web/http_server_context_easyjson.go +++ b/plugins/web/http_server_context_easyjson.go @@ -1,3 +1,8 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + // Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. package web diff --git a/plugins/web/http_server_router.go b/plugins/web/http_server_router.go index e9527ea..e34b570 100644 --- a/plugins/web/http_server_router.go +++ b/plugins/web/http_server_router.go @@ -12,9 +12,9 @@ import ( "net/http" "strings" - "github.com/osspkg/go-sdk/app" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/webutil" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" ) type ( diff --git a/plugins/web/http_server_router_easyjson.go b/plugins/web/http_server_router_easyjson.go index a3b4294..60add82 100644 --- a/plugins/web/http_server_router_easyjson.go +++ b/plugins/web/http_server_router_easyjson.go @@ -1,3 +1,8 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + // Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. package web diff --git a/plugins/web/http_server_router_test.go b/plugins/web/http_server_router_test.go index d4f9ef6..0ec14e6 100644 --- a/plugins/web/http_server_router_test.go +++ b/plugins/web/http_server_router_test.go @@ -13,9 +13,9 @@ import ( "net/http/httptest" "testing" - "github.com/osspkg/go-sdk/ioutil" - "github.com/osspkg/go-sdk/log" - "github.com/osspkg/go-sdk/webutil" + "github.com/osspkg/goppy/sdk/ioutil" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/webutil" "github.com/stretchr/testify/require" ) diff --git a/plugins/web/ws_client.go b/plugins/web/ws_client.go index 3dcebbe..2de2b9b 100644 --- a/plugins/web/ws_client.go +++ b/plugins/web/ws_client.go @@ -7,29 +7,18 @@ package web import ( "context" - "encoding/json" - "fmt" - "net/http" - "sync" - "sync/atomic" - "github.com/gorilla/websocket" - context2 "github.com/osspkg/go-sdk/context" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil/websocket" ) func WithWebsocketClient() plugins.Plugin { return plugins.Plugin{ Inject: func(l log.Logger) (*wscProvider, WebsocketClient) { - ctx, cncl := context.WithCancel(context.Background()) - c := &wscProvider{ - connects: make(map[string]WebsocketClientConn), - log: l, - ctx: ctx, - cancel: cncl, - } + c := newWSClientProvider(l) return c, c }, } @@ -37,48 +26,74 @@ func WithWebsocketClient() plugins.Plugin { type ( wscProvider struct { - connects map[string]WebsocketClientConn + connects map[string]websocket.Client cancel context.CancelFunc ctx context.Context - mux sync.RWMutex - wg sync.WaitGroup + sync iosync.Switch + mux iosync.Lock + wg iosync.Group log log.Logger } WebsocketClient interface { - Create(ctx context.Context, url string, opts ...func(WebsocketClientOption)) (WebsocketClientConn, error) + Create(url string, opts ...func(websocket.ClientOption)) WebsocketClientConnect + } + + WebsocketClientConnect interface { + Encode(eid websocket.EventID, in interface{}) + ConnectID() string + Header(key, value string) + SetHandler(call websocket.ClientHandler, eids ...websocket.EventID) + DelHandler(eids ...websocket.EventID) + OnClose(cb func(cid string)) + OnOpen(cb func(cid string)) + Close() } ) -func (v *wscProvider) Up() error { +func newWSClientProvider(l log.Logger) *wscProvider { + return &wscProvider{ + connects: make(map[string]websocket.Client, 2), + sync: iosync.NewSwitch(), + log: l, + mux: iosync.NewLock(), + wg: iosync.NewGroup(), + } +} + +func (v *wscProvider) Up(ctx app.Context) error { + if v.sync.On() { + v.ctx, v.cancel = context.WithCancel(ctx.Context()) + } return nil } func (v *wscProvider) Down() error { + if !v.sync.Off() { + return nil + } v.cancel() v.wg.Wait() return nil } -func (v *wscProvider) addConn(cc WebsocketClientConn) { - v.wg.Add(1) - lock(&v.mux, func() { +func (v *wscProvider) addConn(cc websocket.Client) { + v.mux.Lock(func() { v.connects[cc.ConnectID()] = cc }) } func (v *wscProvider) delConn(cid string) { - lock(&v.mux, func() { + v.mux.Lock(func() { delete(v.connects, cid) }) - v.wg.Done() } func (v *wscProvider) errLog(cid string, err error, msg string, args ...interface{}) { - if err == nil { + if err == nil || v.log == nil { return } v.log.WithFields(log.Fields{ @@ -87,11 +102,8 @@ func (v *wscProvider) errLog(cid string, err error, msg string, args ...interfac }).Errorf(msg, args...) } -func (v *wscProvider) Create( - ctx context.Context, url string, - opts ...func(WebsocketClientOption), -) (WebsocketClientConn, error) { - cc := newWSCConnect(url, v.errLog, ctx, v.ctx, opts) +func (v *wscProvider) Create(url string, opts ...func(websocket.ClientOption)) WebsocketClientConnect { + cc := websocket.NewClient(v.ctx, url, v.log, opts...) cc.OnClose(func(cid string) { v.delConn(cid) @@ -100,254 +112,10 @@ func (v *wscProvider) Create( v.addConn(cc) }) - return cc, nil -} - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -type ( - wscConn struct { - status int64 - cid string - - url string - headers http.Header - conn *websocket.Conn - - sendC chan []byte - events map[uint]WebsocketClientHandler - - ctx context.Context - cancel context.CancelFunc - - onOpen, onClose []func(cid string) - erw func(cid string, err error, msg string, args ...interface{}) - - cm sync.RWMutex - em sync.RWMutex - } - - WebsocketClientOption interface { - Header(key string, value string) - } - - WebsocketClientHandler func(d WebsocketEventer, c WebsocketClientProcessor) error - - WebsocketClientProcessor interface { - ConnectID() string - OnClose(cb func(cid string)) - Encode(eventID uint, in interface{}) - EncodeEvent(event WebsocketEventer, in interface{}) - } - - WebsocketClientConn interface { - ConnectID() string - Event(call WebsocketClientHandler, eid ...uint) - Encode(id uint, in interface{}) - Close() - Run() error - } -) - -func newWSCConnect( - url string, - erw func(cid string, err error, msg string, args ...interface{}), - ctx1, ctx2 context.Context, - opts []func(WebsocketClientOption), -) *wscConn { - ctx, cancel := context2.Combine(ctx1, ctx2) - cc := &wscConn{ - status: off, - url: url, - headers: make(http.Header), - sendC: make(chan []byte, 128), - events: make(map[uint]WebsocketClientHandler, 128), - ctx: ctx, - cancel: cancel, - onClose: make([]func(string), 0), - erw: erw, - } - - for _, opt := range opts { - opt(cc) - } - - return cc -} - -func (v *wscConn) ConnectID() string { - return v.cid -} - -func (v *wscConn) connect() *websocket.Conn { - return v.conn -} - -func (v *wscConn) cancelFunc() context.CancelFunc { - return v.cancel -} - -func (v *wscConn) done() <-chan struct{} { - return v.ctx.Done() -} - -func (v *wscConn) errLog(cid string, err error, msg string, args ...interface{}) { - v.erw(cid, err, msg, args...) -} - -func (v *wscConn) OnClose(cb func(cid string)) { - lock(&v.cm, func() { - v.onClose = append(v.onClose, cb) - }) -} - -func (v *wscConn) OnOpen(cb func(cid string)) { - lock(&v.cm, func() { - v.onOpen = append(v.onOpen, cb) - }) -} - -func (v *wscConn) Header(key string, value string) { - lock(&v.cm, func() { - v.headers.Set(key, value) - }) -} - -func (v *wscConn) Event(call WebsocketClientHandler, eid ...uint) { - lock(&v.em, func() { - for _, i := range eid { - v.events[i] = call - } - }) -} - -func (v *wscConn) getEventHandler(id uint) (h WebsocketClientHandler, ok bool) { - rwlock(&v.em, func() { - h, ok = v.events[id] - }) - return -} - -func (v *wscConn) Write(b []byte) { - if len(b) == 0 { - return - } - - select { - case v.sendC <- b: - default: - } -} - -func (v *wscConn) dataBus() <-chan []byte { - return v.sendC -} - -func (v *wscConn) Encode(eventID uint, in interface{}) { - eventModel(func(ev *event) { - ev.ID = eventID - ev.Encode(in) - b, err := json.Marshal(ev) - if err != nil { - v.errLog(v.ConnectID(), err, "[ws] encode message: %d", eventID) - return - } - v.Write(b) - }) -} - -func (v *wscConn) EncodeEvent(e WebsocketEventer, in interface{}) { - eventModel(func(ev *event) { - ev.ID = e.EventID() - ev.UID = e.UniqueID() - ev.Encode(in) - b, err := json.Marshal(ev) - if err != nil { - v.errLog(v.ConnectID(), err, "[ws] encode message: %d", e.EventID()) - return - } - v.Write(b) - }) -} - -func (v *wscConn) dataHandler(b []byte) { - eventModel(func(ev *event) { - var ( - err error - msg string - ) - defer func() { - if err != nil { - v.errLog(v.ConnectID(), err, "[ws] "+msg) - } - }() - if err = json.Unmarshal(b, ev); err != nil { - msg = "decode message" - return - } - call, ok := v.getEventHandler(ev.EventID()) - if !ok { - return - } - err = call(ev, v) - if err != nil { - ev.Error(err) - bb, er := json.Marshal(ev) - if er != nil { - msg = fmt.Sprintf("[ws] call event handler: %d", ev.EventID()) - err = errors.Wrap(err, er) - return - } - err = nil - v.Write(bb) - return + v.wg.Background(func() { + if err := cc.DialAndListen(); err != nil { + v.errLog(cc.ConnectID(), err, "[ws] dial to %s", url) } }) -} - -func (v *wscConn) Close() { - if !atomic.CompareAndSwapInt64(&v.status, on, down) { - return - } - v.cancel() -} - -func (v *wscConn) Run() (err error) { - if !atomic.CompareAndSwapInt64(&v.status, off, on) { - return errOneOpenConnect - } - - var resp *http.Response - - if v.conn, resp, err = websocket.DefaultDialer.Dial(v.url, v.headers); err != nil { - atomic.CompareAndSwapInt64(&v.status, on, off) - v.errLog(v.ConnectID(), err, "open connect [%s]", v.url) - return err - } else { - v.cid = resp.Header.Get("Sec-WebSocket-Accept") - } - - defer func() { - if err := resp.Body.Close(); err != nil { - v.errLog(v.ConnectID(), err, "close body connect [%s]", v.url) - } - }() - - rwlock(&v.cm, func() { - for _, fn := range v.onOpen { - fn(v.ConnectID()) - } - }) - - setupPingPong(v.connect()) - go pumpWrite(v) - pumpRead(v) - - rwlock(&v.cm, func() { - for _, fn := range v.onClose { - fn(v.ConnectID()) - } - }) - - return nil + return cc } diff --git a/plugins/web/ws_common.go b/plugins/web/ws_common.go index 4f4982c..ebcbded 100644 --- a/plugins/web/ws_common.go +++ b/plugins/web/ws_common.go @@ -6,113 +6,10 @@ package web import ( - "context" - "net/http" - "time" - - "github.com/gorilla/websocket" - "github.com/osspkg/go-sdk/errors" -) - -const ( - on = 1 - off = 0 - down = 2 + "github.com/osspkg/goppy/sdk/errors" ) var ( errServAlreadyRunning = errors.New("server already running") errServAlreadyStopped = errors.New("server already stopped") - errOneOpenConnect = errors.New("connection can be started once") ) - -/**********************************************************************************************************************/ - -func newWebsocketUpgrader() websocket.Upgrader { - return websocket.Upgrader{ - EnableCompression: true, - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(_ *http.Request) bool { - return true - }, - } -} - -/**********************************************************************************************************************/ - -const ( - pongWait = 60 * time.Second - pingPeriod = pongWait / 3 -) - -func setupPingPong(c *websocket.Conn) { - c.SetPingHandler(func(_ string) error { - return errors.Wrap( - c.SetReadDeadline(time.Now().Add(pongWait)), - //v.conn.SetWriteDeadline(time.Now().Add(pongWait)), - ) - }) - c.SetPongHandler(func(_ string) error { - return errors.Wrap( - c.SetReadDeadline(time.Now().Add(pongWait)), - //v.conn.SetWriteDeadline(time.Now().Add(pongWait)), - ) - }) -} - -/**********************************************************************************************************************/ - -type processor interface { - ConnectID() string - dataHandler(b []byte) - dataBus() <-chan []byte - connect() *websocket.Conn - cancelFunc() context.CancelFunc - done() <-chan struct{} - errLog(cid string, err error, msg string, args ...interface{}) - Close() -} - -func pumpRead(p processor) { - defer p.cancelFunc() - for { - _, message, err := p.connect().ReadMessage() - if err != nil { - if !websocket.IsCloseError(err, 1000, 1001, 1005) { - p.errLog(p.ConnectID(), err, "[ws] read message") - } - return - } - go p.dataHandler(message) - } -} - -func pumpWrite(p processor) { - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - p.errLog(p.ConnectID(), p.connect().Close(), "close connect") - }() - for { - select { - case <-p.done(): - err := p.connect().WriteMessage(websocket.CloseMessage, - websocket.FormatCloseMessage(websocket.CloseNormalClosure, "Bye bye!")) - if err != nil && !errors.Is(err, websocket.ErrCloseSent) { - p.errLog(p.ConnectID(), err, "[ws] send close") - } - return - case m := <-p.dataBus(): - if err := p.connect().WriteMessage(websocket.TextMessage, m); err != nil { - p.errLog(p.ConnectID(), err, "[ws] send message") - return - } - case <-ticker.C: - if err := p.connect().WriteMessage(websocket.PingMessage, nil); err != nil { - p.errLog(p.ConnectID(), err, "[ws] send ping") - return - } - } - } -} diff --git a/plugins/web/ws_event.go b/plugins/web/ws_event.go deleted file mode 100644 index 0118421..0000000 --- a/plugins/web/ws_event.go +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. - * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. - */ - -package web - -//go:generate easyjson - -import ( - "encoding/json" - "sync" -) - -var ( - poolWSEvent = sync.Pool{New: func() interface{} { return &event{} }} -) - -//easyjson:json -type event struct { - ID uint `json:"e"` - Data json.RawMessage `json:"d"` - Err *string `json:"err,omitempty"` - UID json.RawMessage `json:"u,omitempty"` - Updated bool `json:"-"` -} - -func (v *event) EventID() uint { - return v.ID -} - -func (v *event) UniqueID() []byte { - if v.UID == nil { - return nil - } - result := make([]byte, 0, len(v.UID)) - return append(result, v.UID...) -} - -func (v *event) Decode(in interface{}) error { - return json.Unmarshal(v.Data, in) -} - -func (v *event) Encode(in interface{}) { - b, err := json.Marshal(in) - if err != nil { - v.Error(err) - return - } - v.Body(b) -} - -func (v *event) Reset() *event { - v.ID, v.Err, v.UID, v.Data, v.Updated = 0, nil, nil, v.Data[:0], false - return v -} - -func (v *event) Error(e error) { - if e == nil { - return - } - err := e.Error() - v.Err, v.Data, v.Updated = &err, v.Data[:0], true -} - -func (v *event) Body(b []byte) { - v.Err, v.Data, v.Updated = nil, append(v.Data[:0], b...), true -} - -func eventModel(call func(ev *event)) { - m, ok := poolWSEvent.Get().(*event) - if !ok { - m = &event{} - } - call(m) - poolWSEvent.Put(m.Reset()) -} diff --git a/plugins/web/ws_server.go b/plugins/web/ws_server.go index d6254e9..0285974 100644 --- a/plugins/web/ws_server.go +++ b/plugins/web/ws_server.go @@ -5,379 +5,76 @@ package web -//go:generate easyjson - import ( - "context" - "encoding/json" "net/http" - "sync" - "sync/atomic" - "github.com/gorilla/websocket" - context2 "github.com/osspkg/go-sdk/context" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" + ws "github.com/gorilla/websocket" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil/websocket" ) -type WebsocketServerOption func(upg websocket.Upgrader) - -func WebsocketServerOptionCompression(enable bool) WebsocketServerOption { - return func(upg websocket.Upgrader) { +func WebsocketServerOptionCompression(enable bool) func(ws.Upgrader) { + return func(upg ws.Upgrader) { upg.EnableCompression = enable } } -func WebsocketServerOptionBuffer(read, write int) WebsocketServerOption { - return func(upg websocket.Upgrader) { +func WebsocketServerOptionBuffer(read, write int) func(ws.Upgrader) { + return func(upg ws.Upgrader) { upg.ReadBufferSize, upg.WriteBufferSize = read, write } } -func WithWebsocketServer(options ...WebsocketServerOption) plugins.Plugin { +func WithWebsocketServer(options ...func(ws.Upgrader)) plugins.Plugin { return plugins.Plugin{ - Inject: func(l log.Logger) (*wssProvider, WebsocketServer) { - wsu := newWebsocketUpgrader() - for _, option := range options { - option(wsu) - } - wsp := newWsServerProvider(l, wsu) - return wsp, wsp + Inject: func(l log.Logger, ctx app.Context) (*wssProvider, WebsocketServer) { + wsp := newWsServerProvider(l, ctx, options...) + return wsp, wsp.serv }, } } type ( - wssProvider struct { - status int64 - clients map[string]*wssConn - events map[uint]WebsocketServerHandler - upgrade websocket.Upgrader - - ctx context.Context - cancel context.CancelFunc - - cm sync.RWMutex - em sync.RWMutex - - log log.Logger - } - - WebsocketServerHandler func(d WebsocketEventer, c WebsocketServerProcessor) error - WebsocketServer interface { - Handling(ctx Context) - Event(call WebsocketServerHandler, eid ...uint) - Broadcast(t uint, m json.Marshaler) + Handling(w http.ResponseWriter, r *http.Request) + SendEvent(eid websocket.EventID, m interface{}, cids ...string) + Broadcast(eid websocket.EventID, m interface{}) + SetHandler(call websocket.EventHandler, eids ...websocket.EventID) CloseAll() CountConn() int } -) -func newWsServerProvider(l log.Logger, wu websocket.Upgrader) *wssProvider { - c, cancel := context.WithCancel(context.TODO()) + wssProvider struct { + log log.Logger + serv *websocket.Server + sync iosync.Switch + } +) +func newWsServerProvider(l log.Logger, ctx app.Context, options ...func(ws.Upgrader)) *wssProvider { return &wssProvider{ - status: off, - clients: make(map[string]*wssConn), - events: make(map[uint]WebsocketServerHandler), - ctx: c, - cancel: cancel, - log: l, - upgrade: wu, + log: l, + serv: websocket.NewServer(l, ctx.Context(), options...), + sync: iosync.NewSwitch(), } } func (v *wssProvider) Up() error { - if !atomic.CompareAndSwapInt64(&v.status, off, on) { + if !v.sync.On() { return errServAlreadyRunning } + v.log.Infof("Websocket started") return nil } func (v *wssProvider) Down() error { - if !atomic.CompareAndSwapInt64(&v.status, on, off) { + if !v.sync.Off() { return errServAlreadyStopped } - v.CloseAll() + v.serv.CloseAll() + v.log.Infof("Websocket stopped") return nil } - -func (v *wssProvider) Broadcast(t uint, m json.Marshaler) { - eventModel(func(ev *event) { - ev.ID = t - - b, err := m.MarshalJSON() - if err != nil { - v.errLog("*", err, "[ws] Broadcast error") - return - } - ev.Body(b) - - b, err = json.Marshal(ev) - if err != nil { - v.errLog("*", err, "[ws] Broadcast error") - return - } - - v.cm.RLock() - for _, c := range v.clients { - c.Write(b) - } - v.cm.RUnlock() - }) -} - -func (v *wssProvider) CloseAll() { - v.cancel() -} - -func (v *wssProvider) Event(call WebsocketServerHandler, eid ...uint) { - lock(&v.em, func() { - for _, i := range eid { - v.events[i] = call - } - }) -} - -func (v *wssProvider) addConn(c *wssConn) { - lock(&v.cm, func() { - v.clients[c.ConnectID()] = c - }) -} - -func (v *wssProvider) delConn(id string) { - lock(&v.cm, func() { - delete(v.clients, id) - }) -} - -func (v *wssProvider) CountConn() (cc int) { - rwlock(&v.cm, func() { - cc = len(v.clients) - }) - return -} - -func (v *wssProvider) getEventHandler(id uint) (h WebsocketServerHandler, ok bool) { - rwlock(&v.em, func() { - h, ok = v.events[id] - }) - return -} - -func (v *wssProvider) errLog(cid string, err error, msg string, args ...interface{}) { - if err == nil { - return - } - v.log.WithFields(log.Fields{ - "cid": cid, - "err": err.Error(), - }).Errorf(msg, args...) -} - -func (v *wssProvider) Handling(ctx Context) { - cid := ctx.Header().Get("Sec-Websocket-Key") - - wsc, err := v.upgrade.Upgrade(ctx.Response(), ctx.Request(), nil) - if err != nil { - v.errLog(cid, err, "[ws] upgrade") - ctx.Error(http.StatusBadRequest, err) - return - } - - c := newWSSConnect(cid, v.getEventHandler, v.errLog, wsc, ctx.Context(), v.ctx) - - c.OnClose(func(cid string) { - v.delConn(cid) - }) - c.OnOpen(func(string) { - v.addConn(c) - }) - - c.Run() -} - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -type ( - wssConn struct { - status int64 - cid string - - conn *websocket.Conn - sendC chan []byte - - ctx context.Context - cancel context.CancelFunc - - onClose, onOpen []func(cid string) - erw func(cid string, err error, msg string, args ...interface{}) - event func(id uint) (WebsocketServerHandler, bool) - - mux sync.RWMutex - } - - WebsocketServerProcessor interface { - ConnectID() string - OnClose(cb func(cid string)) - OnOpen(cb func(cid string)) - Encode(eventID uint, in interface{}) - EncodeEvent(event WebsocketEventer, in interface{}) - } - - WebsocketEventer interface { - EventID() uint - UniqueID() []byte - Decode(in interface{}) error - } -) - -func newWSSConnect( - cid string, - e func(id uint) (WebsocketServerHandler, bool), - erw func(cid string, err error, msg string, args ...interface{}), - wc *websocket.Conn, - ctxs ...context.Context, -) *wssConn { - ctx, cancel := context2.Combine(ctxs...) - return &wssConn{ - status: off, - cid: cid, - ctx: ctx, - cancel: cancel, - onClose: make([]func(string), 0), - onOpen: make([]func(string), 0), - sendC: make(chan []byte, 128), - erw: erw, - event: e, - conn: wc, - } -} - -func (v *wssConn) ConnectID() string { - return v.cid -} - -func (v *wssConn) connect() *websocket.Conn { - return v.conn -} - -func (v *wssConn) cancelFunc() context.CancelFunc { - return v.cancel -} - -func (v *wssConn) done() <-chan struct{} { - return v.ctx.Done() -} - -func (v *wssConn) errLog(cid string, err error, msg string, args ...interface{}) { - v.erw(cid, err, msg, args...) -} - -func (v *wssConn) OnClose(cb func(cid string)) { - lock(&v.mux, func() { - v.onClose = append(v.onClose, cb) - }) -} - -func (v *wssConn) OnOpen(cb func(cid string)) { - lock(&v.mux, func() { - v.onOpen = append(v.onOpen, cb) - }) -} - -func (v *wssConn) Encode(eventID uint, in interface{}) { - eventModel(func(ev *event) { - ev.ID = eventID - ev.Encode(in) - b, err := json.Marshal(ev) - if err != nil { - v.errLog(v.cid, err, "[ws] encode message: %d", eventID) - return - } - v.Write(b) - }) -} - -func (v *wssConn) EncodeEvent(e WebsocketEventer, in interface{}) { - eventModel(func(ev *event) { - ev.ID = e.EventID() - ev.UID = e.UniqueID() - ev.Encode(in) - b, err := json.Marshal(ev) - if err != nil { - v.errLog(v.cid, err, "[ws] encode message: %d", e.EventID()) - return - } - v.Write(b) - }) -} - -func (v *wssConn) Write(b []byte) { - if len(b) == 0 { - return - } - - select { - case v.sendC <- b: - default: - } -} - -func (v *wssConn) dataBus() <-chan []byte { - return v.sendC -} - -func (v *wssConn) dataHandler(b []byte) { - eventModel(func(ev *event) { - if err := json.Unmarshal(b, ev); err != nil { - v.errLog(v.cid, err, "[ws] decode message") - return - } - call, ok := v.event(ev.EventID()) - if !ok { - return - } - if err := call(ev, v); err != nil { - ev.Error(err) - if bb, er := json.Marshal(ev); er != nil { - v.errLog(v.cid, errors.Wrap(err, er), "[ws] call event handler: %d", ev.EventID()) - return - } else { - v.Write(bb) - } - return - } - }) -} - -func (v *wssConn) Close() { - if !atomic.CompareAndSwapInt64(&v.status, on, down) { - return - } - v.errLog(v.ConnectID(), v.conn.Close(), "close connect") -} - -func (v *wssConn) Run() { - if !atomic.CompareAndSwapInt64(&v.status, off, on) { - return - } - - rwlock(&v.mux, func() { - for _, fn := range v.onOpen { - fn(v.ConnectID()) - } - }) - - setupPingPong(v.connect()) - go pumpWrite(v) - pumpRead(v) - - rwlock(&v.mux, func() { - for _, fn := range v.onClose { - fn(v.ConnectID()) - } - }) -} diff --git a/plugins/web/ws_server_easyjson.go b/plugins/web/ws_server_easyjson.go deleted file mode 100644 index a3b4294..0000000 --- a/plugins/web/ws_server_easyjson.go +++ /dev/null @@ -1,18 +0,0 @@ -// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. - -package web - -import ( - json "encoding/json" - easyjson "github.com/mailru/easyjson" - jlexer "github.com/mailru/easyjson/jlexer" - jwriter "github.com/mailru/easyjson/jwriter" -) - -// suppress unused package warning -var ( - _ *json.RawMessage - _ *jlexer.Lexer - _ *jwriter.Writer - _ easyjson.Marshaler -) diff --git a/plugins/web/ws_server_pool.go b/plugins/web/ws_server_pool.go index 6a02d26..ddc8dd8 100644 --- a/plugins/web/ws_server_pool.go +++ b/plugins/web/ws_server_pool.go @@ -6,19 +6,22 @@ package web import ( + "context" "sync" - "github.com/osspkg/go-sdk/errors" - "github.com/osspkg/go-sdk/log" + ws "github.com/gorilla/websocket" "github.com/osspkg/goppy/plugins" + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil/websocket" ) -func WithWebsocketServerPool(options ...WebsocketServerOption) plugins.Plugin { +func WithWebsocketServerPool(options ...func(ws.Upgrader)) plugins.Plugin { return plugins.Plugin{ Inject: func(l log.Logger) (*wssPool, WebsocketServerPool) { wssp := &wssPool{ options: options, - pool: make(map[string]*wssProvider, 10), + pool: make(map[string]*websocket.Server, 10), log: l, } return wssp, wssp @@ -28,9 +31,10 @@ func WithWebsocketServerPool(options ...WebsocketServerOption) plugins.Plugin { type ( wssPool struct { - options []WebsocketServerOption - pool map[string]*wssProvider + options []func(ws.Upgrader) + pool map[string]*websocket.Server log log.Logger + ctx context.Context mux sync.Mutex } @@ -46,25 +50,13 @@ func (v *wssPool) Create(name string) WebsocketServer { if p, ok := v.pool[name]; ok { return p } - - u := newWebsocketUpgrader() - for _, option := range v.options { - option(u) - } - p := newWsServerProvider(v.log, u) + p := websocket.NewServer(v.log, v.ctx, v.options...) v.pool[name] = p - - if err := p.Up(); err != nil { - v.log.WithFields(log.Fields{ - "err": err, - "name": name, - }).Errorf("Create Websocket Server in pool") - } - return p } -func (v *wssPool) Up() error { +func (v *wssPool) Up(ctx app.Context) error { + v.ctx = ctx.Context() return nil } @@ -72,12 +64,9 @@ func (v *wssPool) Down() error { v.mux.Lock() defer v.mux.Unlock() - var err error for _, item := range v.pool { - if e := item.Down(); e != nil { - err = errors.Wrap(err, e) - } + item.CloseAll() } - return err + return nil } diff --git a/sdk/acl/acl.go b/sdk/acl/acl.go new file mode 100644 index 0000000..9d29c8c --- /dev/null +++ b/sdk/acl/acl.go @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl + +import ( + "context" + "time" + + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + errFeatureGreaterMax = errors.New("feature number is greater than the maximum") + errUserNotFound = errors.New("user not found") + errChangeNotSupported = errors.New("changing ACL is not supported") +) + +type ( + ACL interface { + GetAll(email string) ([]uint8, error) + Get(email string, feature uint16) (uint8, error) + Set(email string, feature uint16, level uint8) error + Flush(email string) + AutoFlush(ctx context.Context, interval time.Duration) + } + + Storage interface { + FindACL(email string) (string, error) + ChangeACL(email, access string) error + } +) + +type _acl struct { + cache *cache + store Storage +} + +func NewACL(store Storage, size uint) ACL { + return &_acl{ + store: store, + cache: newCache(size), + } +} + +func (v *_acl) AutoFlush(ctx context.Context, interval time.Duration) { + tick := time.NewTicker(interval) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return + case ts := <-tick.C: + v.cache.FlushByTime(ts.Unix()) + } + } +} + +func (v *_acl) GetAll(email string) ([]uint8, error) { + if !v.cache.Has(email) { + if err := v.loadFromStore(email); err != nil { + return nil, err + } + } + + return v.cache.GetAll(email) +} + +func (v *_acl) Get(email string, feature uint16) (uint8, error) { + if !v.cache.Has(email) { + if err := v.loadFromStore(email); err != nil { + return 0, err + } + } + + return v.cache.Get(email, feature) +} + +func (v *_acl) Set(email string, feature uint16, level uint8) error { + if !v.cache.Has(email) { + if err := v.loadFromStore(email); err != nil { + return err + } + } + + if err := v.cache.Set(email, feature, level); err != nil { + return err + } + return v.saveToStore(email) +} + +func (v *_acl) Flush(email string) { + v.cache.Flush(email) +} + +func (v *_acl) loadFromStore(email string) error { + access, err := v.store.FindACL(email) + if err != nil { + return errors.Wrap(err, errUserNotFound) + } + v.cache.SetAll(email, str2uint(access)...) + return nil +} + +func (v *_acl) saveToStore(email string) error { + access, err := v.cache.GetAll(email) + if err != nil { + return err + } + + err = v.store.ChangeACL(email, uint2str(access...)) + return errors.Wrapf(err, "change acl") +} diff --git a/sdk/acl/acl_test.go b/sdk/acl/acl_test.go new file mode 100644 index 0000000..9cc4a5c --- /dev/null +++ b/sdk/acl/acl_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl_test + +import ( + "testing" + + acl2 "github.com/osspkg/goppy/sdk/acl" + "github.com/stretchr/testify/require" +) + +func TestUnit_NewACL(t *testing.T) { + store := acl2.NewInMemoryStorage() + acl := acl2.NewACL(store, 3) + + email := "demo@example.com" + + t.Log("user not exist") + + levels, err := acl.GetAll(email) + require.Error(t, err) + require.Nil(t, levels) + + require.Error(t, acl.Set(email, 10, 1)) + + t.Log("user exist") + + require.NoError(t, store.ChangeACL(email, "")) + + require.Error(t, acl.Set(email, 10, 1)) + + levels, err = acl.GetAll(email) + require.NoError(t, err) + require.Equal(t, []uint8{0, 0, 0}, levels) + + require.NoError(t, acl.Set(email, 2, 10)) + + levels, err = acl.GetAll(email) + require.NoError(t, err) + require.Equal(t, []uint8{0, 0, 9}, levels) +} diff --git a/sdk/acl/cache.go b/sdk/acl/cache.go new file mode 100644 index 0000000..ac386e1 --- /dev/null +++ b/sdk/acl/cache.go @@ -0,0 +1,149 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl + +import ( + "sync" + "time" +) + +type ( + cache struct { + size uint + data map[string]*item + mux sync.Mutex + } + item struct { + Val []uint8 + Ts int64 + } +) + +func newCache(size uint) *cache { + return &cache{ + size: size, + data: make(map[string]*item), + } +} + +func (v *cache) Has(email string) bool { + v.mux.Lock() + defer v.mux.Unlock() + + _, ok := v.data[email] + + return ok +} + +func (v *cache) Get(email string, feature uint16) (uint8, error) { + v.mux.Lock() + defer v.mux.Unlock() + + access, ok := v.data[email] + if !ok { + return 0, errUserNotFound + } + + if feature > uint16(v.size-1) { + return 0, errFeatureGreaterMax + } + + access.Ts = time.Now().Unix() + return access.Val[feature], nil +} + +func (v *cache) GetAll(email string) ([]uint8, error) { + v.mux.Lock() + defer v.mux.Unlock() + + access, ok := v.data[email] + if !ok { + return nil, errUserNotFound + } + + access.Ts = ttl() + + tmp := make([]uint8, v.size) + for i, level := range access.Val { + if uint(i) >= v.size { + break + } + tmp[i] = validateLevel(level) + } + + return tmp, nil +} + +func (v *cache) Set(email string, feature uint16, level uint8) error { + v.mux.Lock() + defer v.mux.Unlock() + + if feature > uint16(v.size-1) { + return errFeatureGreaterMax + } + + access, ok := v.data[email] + if !ok { + access = &item{Val: make([]uint8, v.size)} + v.data[email] = access + } + + access.Ts = ttl() + access.Val[feature] = validateLevel(level) + return nil +} + +func (v *cache) SetAll(email string, levels ...uint8) { + v.mux.Lock() + defer v.mux.Unlock() + + access, ok := v.data[email] + if !ok { + access = &item{Val: make([]uint8, v.size)} + v.data[email] = access + } + + access.Ts = ttl() + for i, level := range levels { + if uint(i) >= v.size { + break + } + access.Val[i] = validateLevel(level) + } +} + +func (v *cache) Flush(email string) { + v.mux.Lock() + defer v.mux.Unlock() + + delete(v.data, email) +} + +func (v *cache) FlushByTime(ts int64) { + v.mux.Lock() + defer v.mux.Unlock() + + for email, access := range v.data { + if access.Ts < ts { + delete(v.data, email) + } + } +} + +func (v *cache) List() []string { + v.mux.Lock() + defer v.mux.Unlock() + + tmp := make([]string, 0, len(v.data)) + for email := range v.data { + tmp = append(tmp, email) + } + return tmp +} + +func ttl() int64 { + return time.Now().Add(time.Hour).Unix() +} diff --git a/sdk/acl/store_inconfig.go b/sdk/acl/store_inconfig.go new file mode 100644 index 0000000..b2f9db7 --- /dev/null +++ b/sdk/acl/store_inconfig.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl + +type ( + storeInConfig struct { + data map[string]string + } + + ConfigInConfigStorage struct { + ACL map[string]string `yaml:"acl_users"` + } +) + +func NewInConfigStorage(c *ConfigInConfigStorage) Storage { + v := &storeInConfig{} + + v.data = make(map[string]string, len(c.ACL)) + for key, val := range c.ACL { + v.data[key] = val + } + + return v +} + +func (v *storeInConfig) FindACL(email string) (string, error) { + if acl, ok := v.data[email]; ok { + return acl, nil + } + return "", errUserNotFound +} + +func (v *storeInConfig) ChangeACL(email, data string) error { + return errChangeNotSupported +} diff --git a/sdk/acl/store_inconfig_test.go b/sdk/acl/store_inconfig_test.go new file mode 100644 index 0000000..325cb3b --- /dev/null +++ b/sdk/acl/store_inconfig_test.go @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/acl" + "github.com/stretchr/testify/require" +) + +func TestUnit_NewInConfigStorage(t *testing.T) { + conf := &acl.ConfigInConfigStorage{ACL: map[string]string{ + "u1": "123", + "u2": "456", + }} + store := acl.NewInConfigStorage(conf) + require.NotNil(t, store) + + val, err := store.FindACL("u1") + require.NoError(t, err) + require.Equal(t, "123", val) + + val, err = store.FindACL("u2") + require.NoError(t, err) + require.Equal(t, "456", val) + + val, err = store.FindACL("u3") + require.Error(t, err) + require.Equal(t, "", val) + + err = store.ChangeACL("u2", "789") + require.Error(t, err) + + err = store.ChangeACL("u5", "333") + require.Error(t, err) +} diff --git a/sdk/acl/store_inmemory.go b/sdk/acl/store_inmemory.go new file mode 100644 index 0000000..1c5eb34 --- /dev/null +++ b/sdk/acl/store_inmemory.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl + +import ( + "sync" +) + +type OptionInMemoryStorage func(v *storeInMemory) + +func OptionInMemoryStorageSetupData(data map[string]string) OptionInMemoryStorage { + return func(v *storeInMemory) { + v.data = make(map[string]string, len(data)) + for key, val := range data { + v.data[key] = val + } + } +} + +type storeInMemory struct { + data map[string]string + mux sync.Mutex +} + +func NewInMemoryStorage(opts ...OptionInMemoryStorage) Storage { + v := &storeInMemory{ + data: make(map[string]string), + } + + for _, opt := range opts { + opt(v) + } + + return v +} + +func (v *storeInMemory) FindACL(email string) (string, error) { + v.mux.Lock() + defer v.mux.Unlock() + + if acl, ok := v.data[email]; ok { + return acl, nil + } + return "", errUserNotFound +} + +func (v *storeInMemory) ChangeACL(email, data string) error { + v.mux.Lock() + defer v.mux.Unlock() + + v.data[email] = data + return nil +} diff --git a/sdk/acl/store_inmemory_test.go b/sdk/acl/store_inmemory_test.go new file mode 100644 index 0000000..e15ab9b --- /dev/null +++ b/sdk/acl/store_inmemory_test.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/acl" + "github.com/stretchr/testify/require" +) + +func TestUnit_NewInMemoryStorage(t *testing.T) { + opt := acl.OptionInMemoryStorageSetupData(map[string]string{ + "u1": "123", + "u2": "456", + }) + store := acl.NewInMemoryStorage(opt) + require.NotNil(t, store) + + val, err := store.FindACL("u1") + require.NoError(t, err) + require.Equal(t, "123", val) + + val, err = store.FindACL("u2") + require.NoError(t, err) + require.Equal(t, "456", val) + + val, err = store.FindACL("u3") + require.Error(t, err) + require.Equal(t, "", val) + + err = store.ChangeACL("u2", "789") + require.NoError(t, err) + + val, err = store.FindACL("u2") + require.NoError(t, err) + require.Equal(t, "789", val) + + val, err = store.FindACL("u5") + require.Error(t, err) + require.Equal(t, "", val) + + err = store.ChangeACL("u5", "333") + require.NoError(t, err) + + val, err = store.FindACL("u5") + require.NoError(t, err) + require.Equal(t, "333", val) +} diff --git a/sdk/acl/utils.go b/sdk/acl/utils.go new file mode 100644 index 0000000..a57763b --- /dev/null +++ b/sdk/acl/utils.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package acl + +import ( + "strconv" + "strings" +) + +const MaxLevel = uint8(9) + +func validateLevel(v uint8) uint8 { + if v > MaxLevel { + return MaxLevel + } + return v +} + +func str2uint(data string) []uint8 { + t := make([]uint8, len(data)) + for i, s := range strings.Split(data, "") { + v, err := strconv.ParseUint(s, 10, 8) + if err != nil { + t[i] = 0 + continue + } + b := uint8(v) + if b > MaxLevel { + t[i] = 9 + } else { + t[i] = uint8(b) + } + } + return t +} + +func uint2str(data ...uint8) string { + t := "" + for _, v := range data { + if v > MaxLevel { + v = MaxLevel + } + t += strconv.FormatUint(uint64(v), 10) + } + return t +} diff --git a/sdk/app/README.md b/sdk/app/README.md new file mode 100644 index 0000000..4735400 --- /dev/null +++ b/sdk/app/README.md @@ -0,0 +1,136 @@ +# Application as service + +## Base config file + +***config.yaml*** + +```yaml +env: dev +level: 3 +log: /var/log/simple.log +pig: /var/run/simple.pid +``` + +level: +* 0 - error only +* 1 - + warning +* 2 - + info +* 3 - + debug + +## Example + +```go +package main + +import ( + "fmt" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" +) + +type ( + //Simple model + Simple struct{} + //Config model + Config struct { + Env string `yaml:"env"` + } +) + +//NewSimple init Simple +func NewSimple(_ Config) *Simple { + fmt.Println("--> call NewSimple") + return &Simple{} +} + +//Up method for start Simple in DI container +func (s *Simple) Up(_ app.Context) error { + fmt.Println("--> call *Simple.Up") + return nil +} + +//Down method for stop Simple in DI container +func (s *Simple) Down(_ app.Context) error { + fmt.Println("--> call *Simple.Down") + return nil +} + +func main() { + app.New(). + Logger(log.Default()). + ConfigFile( + "./config.yaml", + Config{}, + ). + Modules( + NewSimple, + ). + Run() +} +``` + +## HowTo + +***Run the app*** +```go +app.New() + .ConfigFile(, ) + .Modules() + .Run() +``` + +***Supported types for initialization*** + +* Function that returns an object or interface + +*All incoming dependencies will be injected automatically* +```go +type Simple1 struct{} +func NewSimple1(_ *log.Logger) *Simple1 { return &Simple1{} } +``` + +*Returns the interface* +```go +type Simple2 struct{} +type Simple2Interface interface{ + Get() string +} +func NewSimple2() Simple2Interface { return &Simple2{} } +func (s2 *Simple2) Get() string { + return "Hello world" +} +``` + +*If the object has the `Up(app.Context) error` and `Down() error` methods, they will be called `Up(app.Context) error` when the app starts, and `Down() error` when it finishes. This allows you to automatically start and stop routine processes inside the module* + +```go +var _ service.IServiceCtx = (*Simple3)(nil) +type Simple3 struct{} +func NewSimple3(_ *Simple4) *Simple3 { return &Simple3{} } +func (s3 *Simple3) Up(_ app.Context) error { return nil } +func (s3 *Simple3) Down(_ app.Context) error { return nil } +``` + +* Named type + +```go +type HelloWorld string +``` + +* Object structure + +```go +type Simple4 struct{ + S1 *Simple1 + S2 Simple2Interface + HW HelloWorld +} +``` + +* Object reference or type + +```go +s1 := &Simple1{} +hw := HelloWorld("Hello!!") +``` diff --git a/sdk/app/application.go b/sdk/app/application.go new file mode 100644 index 0000000..7273268 --- /dev/null +++ b/sdk/app/application.go @@ -0,0 +1,252 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "github.com/osspkg/goppy/sdk/console" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/syscall" +) + +type ( + //ENV type for environments (prod, dev, stage, etc) + ENV string + + App interface { + Logger(log log.Logger) App + Modules(modules ...interface{}) App + ConfigFile(filename string, configs ...interface{}) App + PidFile(filename string) App + Run() + Invoke(call interface{}) + ExitFunc(call func(code int)) App + } + + _app struct { + cfile string + pidfile string + configs Modules + modules Modules + sources Sources + packages *_dic + logout *_log + log log.Logger + ctx Context + exitFunc func(code int) + } +) + +// New create application +func New() App { + ctx := NewContext() + return &_app{ + modules: Modules{}, + configs: Modules{}, + packages: newDic(ctx), + ctx: ctx, + exitFunc: func(_ int) {}, + } +} + +// Logger setup logger +func (a *_app) Logger(log log.Logger) App { + a.log = log + return a +} + +// Modules append object to modules list +func (a *_app) Modules(modules ...interface{}) App { + for _, mod := range modules { + switch v := mod.(type) { + case Modules: + a.modules = a.modules.Add(v...) + default: + a.modules = a.modules.Add(v) + } + } + + return a +} + +// ConfigFile set config file path and configs models +func (a *_app) ConfigFile(filename string, configs ...interface{}) App { + a.cfile = filename + for _, config := range configs { + a.configs = a.configs.Add(config) + } + + return a +} + +func (a *_app) PidFile(filename string) App { + a.pidfile = filename + return a +} + +func (a *_app) ExitFunc(v func(code int)) App { + a.exitFunc = v + return a +} + +// Run application +func (a *_app) Run() { + a.prepareConfig(false) + + result := a.steps( + []step{ + { + Message: "Registering dependencies", + Call: func() error { return a.packages.Register(a.modules...) }, + }, + { + Message: "Running dependencies", + Call: func() error { return a.packages.Build() }, + }, + }, + func(er bool) { + if er { + a.ctx.Close() + return + } + go syscall.OnStop(a.ctx.Close) + <-a.ctx.Done() + }, + []step{ + { + Message: "Stop dependencies", + Call: func() error { return a.packages.Down() }, + }, + }, + ) + console.FatalIfErr(a.logout.Close(), "close log file") + if result { + a.exitFunc(1) + } + a.exitFunc(0) +} + +// Invoke run application +func (a *_app) Invoke(call interface{}) { + a.prepareConfig(true) + + result := a.steps( + []step{ + { + Call: func() error { return a.packages.Register(a.modules...) }, + }, + { + Call: func() error { return a.packages.Invoke(call) }, + }, + }, + func(_ bool) {}, + []step{ + { + Call: func() error { return a.packages.Down() }, + }, + }, + ) + console.FatalIfErr(a.logout.Close(), "close log file") + if result { + a.exitFunc(1) + } + a.exitFunc(0) +} + +func (a *_app) prepareConfig(interactive bool) { + var err error + if len(a.cfile) == 0 { + a.logout = newLog(&Config{ + Level: 4, + LogFile: "/dev/stdout", + }) + a.log = log.Default() + a.logout.Handler(a.log) + } + if len(a.cfile) > 0 { + // read config file + a.sources = Sources(a.cfile) + + // init logger + config := &Config{} + if err = a.sources.Decode(config); err != nil { + console.FatalIfErr(err, "decode config file: %s", a.cfile) + } + if interactive { + config.Level = 4 + config.LogFile = "/dev/stdout" + } + a.logout = newLog(config) + if a.log == nil { + a.log = log.Default() + } + a.logout.Handler(a.log) + a.modules = a.modules.Add( + ENV(config.Env), + ) + // decode all configs + var configs []interface{} + configs, err = typingRefPtr(a.configs, func(i interface{}) error { + return a.sources.Decode(i) + }) + if err != nil { + a.log.WithFields(log.Fields{ + "err": err.Error(), + }).Fatalf("Decode config file") + } + a.modules = a.modules.Add(configs...) + + if !interactive && len(a.pidfile) > 0 { + if err = syscall.Pid(a.pidfile); err != nil { + a.log.WithFields(log.Fields{ + "err": err.Error(), + "file": a.pidfile, + }).Fatalf("Create pid file") + } + } + } + a.modules = a.modules.Add( + func() log.Logger { return a.log }, + func() Context { return a.ctx }, + ) +} + +type step struct { + Call func() error + Message string +} + +func (a *_app) steps(up []step, wait func(bool), down []step) bool { + var erc int + + for _, s := range up { + if len(s.Message) > 0 { + a.log.Infof(s.Message) + } + if err := s.Call(); err != nil { + a.log.WithFields(log.Fields{ + "err": err.Error(), + }).Errorf(s.Message) + erc++ + break + } + } + + wait(erc > 0) + + for _, s := range down { + if len(s.Message) > 0 { + a.log.Infof(s.Message) + } + if err := s.Call(); err != nil { + a.log.WithFields(log.Fields{ + "err": err.Error(), + }).Errorf(s.Message) + erc++ + } + } + + return erc > 0 +} diff --git a/sdk/app/application_test.go b/sdk/app/application_test.go new file mode 100644 index 0000000..1711421 --- /dev/null +++ b/sdk/app/application_test.go @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/app" + "github.com/stretchr/testify/require" +) + +func TestUnit_AppInvoke(t *testing.T) { + out := "" + call1 := func(s *Struct1) { + s.Do(&out) + out += "Done" + } + app.New().Modules( + &Struct1{}, &Struct2{}, + ).Invoke(call1) + require.Equal(t, "[Struct1.Do]Done", out) + + out = "" + call1 = func(s *Struct1) { + s.Do2(&out) + out += "Done" + } + app.New().ExitFunc(func(code int) { + t.Log("Exit Code", code) + require.Equal(t, 0, code) + }).Modules( + NewStruct1, &Struct2{}, + ).Invoke(call1) + require.Equal(t, "[Struct1.Do][Struct2.Do]Done", out) +} + +type Struct1 struct{ s *Struct2 } + +func NewStruct1(s2 *Struct2) *Struct1 { + return &Struct1{s: s2} +} +func (*Struct1) Do(v *string) { *v += "[Struct1.Do]" } +func (s *Struct1) Do2(v *string) { + *v += "[Struct1.Do]" + s.s.Do(v) +} + +type Struct2 struct{} + +func (*Struct2) Do(v *string) { *v += "[Struct2.Do]" } diff --git a/sdk/app/config.go b/sdk/app/config.go new file mode 100644 index 0000000..6ca8a98 --- /dev/null +++ b/sdk/app/config.go @@ -0,0 +1,13 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +// Config config model +type Config struct { + Env string `yaml:"env"` + Level uint32 `yaml:"level"` + LogFile string `yaml:"log"` +} diff --git a/sdk/app/container.go b/sdk/app/container.go new file mode 100644 index 0000000..8c42df4 --- /dev/null +++ b/sdk/app/container.go @@ -0,0 +1,415 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "fmt" + "reflect" + "sync" + + "github.com/osspkg/go-algorithms/graph/kahn" + "github.com/osspkg/goppy/sdk/errors" +) + +type _dic struct { + kahn *kahn.Graph + srv *_serv + list *dicMap +} + +func newDic(ctx Context) *_dic { + return &_dic{ + kahn: kahn.New(), + srv: newService(ctx), + list: newDicMap(), + } +} + +// Down - stop all services in dependencies +func (v *_dic) Down() error { + return v.srv.Down() +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Register - register a new dependency +func (v *_dic) Register(items ...interface{}) error { + if v.srv.IsUp() { + return errDepBuilderNotRunning + } + + for _, item := range items { + ref := reflect.TypeOf(item) + switch ref.Kind() { + + case reflect.Struct: + if err := v.list.Add(item, item, typeExist); err != nil { + return err + } + + case reflect.Func: + for i := 0; i < ref.NumIn(); i++ { + in := ref.In(i) + if in.Kind() == reflect.Struct { + if err := v.list.Add(in, reflect.New(in).Elem().Interface(), typeNewIfNotExist); err != nil { + return err + } + } + + } + if ref.NumOut() == 0 { + if err := v.list.Add(ref, item, typeNew); err != nil { + return err + } + continue + } + for i := 0; i < ref.NumOut(); i++ { + if err := v.list.Add(ref.Out(i), item, typeNew); err != nil { + return err + } + } + + default: + if err := v.list.Add(item, item, typeExist); err != nil { + return err + } + } + } + + return nil +} + +// Build - initialize dependencies +func (v *_dic) Build() error { + if err := v.srv.MakeAsUp(); err != nil { + return err + } + + err := v.list.foreach(v.calcFunc, v.calcStruct, v.calcOther) + if err != nil { + return errors.Wrapf(err, "building dependency graph") + } + + if err = v.kahn.Build(); err != nil { + return errors.Wrapf(err, "dependency graph calculation") + } + + return v.exec() +} + +// Inject - obtained dependence +func (v *_dic) Inject(item interface{}) error { + _, err := v.callArgs(item) + return err +} + +// Invoke - obtained dependence +func (v *_dic) Invoke(item interface{}) error { + ref := reflect.TypeOf(item) + addr, ok := getRefAddr(ref) + if !ok { + return fmt.Errorf("resolve invoke reference") + } + + if err := v.Register(item); err != nil { + return err + } + + if err := v.srv.MakeAsUp(); err != nil { + return err + } + + err := v.list.foreach(v.calcFunc, v.calcStruct, v.calcOther) + if err != nil { + return errors.Wrapf(err, "building dependency graph") + } + + v.kahn.BreakPoint(addr) + + if err = v.kahn.Build(); err != nil { + return errors.Wrapf(err, "dependency graph calculation") + } + + return v.exec() +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +var empty = "EMPTY" + +func (v *_dic) calcFunc(outAddr string, outRef reflect.Type) error { + if outRef.NumIn() == 0 { + if err := v.kahn.Add(empty, outAddr); err != nil { + return errors.Wrapf(err, "cant add [->%s] to graph", outAddr) + } + } + + for i := 0; i < outRef.NumIn(); i++ { + inRef := outRef.In(i) + inAddr, _ := getRefAddr(inRef) + + //TODO: need? + //if _, err := v.list.Get(inAddr); err != nil { + // return errors.Wrapf(err, "cant add [%s->%s] to graph", inAddr, outAddr) + //} + if err := v.kahn.Add(inAddr, outAddr); err != nil { + return errors.Wrapf(err, "cant add [%s->%s] to graph", inAddr, outAddr) + } + } + + return nil +} + +func (v *_dic) calcStruct(outAddr string, outRef reflect.Type) error { + if outRef.NumField() == 0 { + if err := v.kahn.Add(empty, outAddr); err != nil { + return errors.Wrapf(err, "cant add [->%s] to graph", outAddr) + } + return nil + } + for i := 0; i < outRef.NumField(); i++ { + inRef := outRef.Field(i).Type + inAddr, _ := getRefAddr(inRef) + + //TODO: need? + //if _, err := v.list.Get(inAddr); err != nil { + // return errors.Wrapf(err, "cant add [%s->%s] to graph", inAddr, outAddr) + //} + if err := v.kahn.Add(inAddr, outAddr); err != nil { + return errors.Wrapf(err, "cant add [%s->%s] to graph", inAddr, outAddr) + } + } + return nil +} + +func (v *_dic) calcOther(_ string, _ reflect.Type) error { + return nil +} + +func (v *_dic) callFunc(item interface{}) ([]reflect.Value, error) { + ref := reflect.TypeOf(item) + args := make([]reflect.Value, 0, ref.NumIn()) + + for i := 0; i < ref.NumIn(); i++ { + inRef := ref.In(i) + inAddr, _ := getRefAddr(inRef) + vv, err := v.list.Get(inAddr) + if err != nil { + return nil, err + } + args = append(args, reflect.ValueOf(vv)) + } + + args = reflect.ValueOf(item).Call(args) + for _, arg := range args { + if err, ok := arg.Interface().(error); ok && err != nil { + return nil, err + } + } + + return args, nil +} + +func (v *_dic) callStruct(item interface{}) ([]reflect.Value, error) { + ref := reflect.TypeOf(item) + value := reflect.New(ref) + args := make([]reflect.Value, 0, ref.NumField()) + + for i := 0; i < ref.NumField(); i++ { + inRef := ref.Field(i) + inAddr, _ := getRefAddr(inRef.Type) + vv, err := v.list.Get(inAddr) + if err != nil { + return nil, err + } + value.Elem().FieldByName(inRef.Name).Set(reflect.ValueOf(vv)) + } + + return append(args, value.Elem()), nil +} + +func (v *_dic) callArgs(item interface{}) ([]reflect.Value, error) { + ref := reflect.TypeOf(item) + + switch ref.Kind() { + case reflect.Func: + return v.callFunc(item) + case reflect.Struct: + return v.callStruct(item) + default: + return []reflect.Value{reflect.ValueOf(item)}, nil + } +} + +func (v *_dic) exec() error { + names := make(map[string]struct{}) + for _, name := range v.kahn.Result() { + if name == empty { + continue + } + names[name] = struct{}{} + } + + for _, name := range v.kahn.Result() { + if _, ok := names[name]; !ok { + continue + } + if v.list.HasType(name, typeExist) { + continue + } + + item, err := v.list.Get(name) + if err != nil { + return err + } + + args, err := v.callArgs(item) + if err != nil { + return errors.Wrapf(err, "initialize error [%s]", name) + } + + for _, arg := range args { + addr, _ := getRefAddr(arg.Type()) + if vv, ok := asService(arg); ok { + if err = v.srv.AddAndRun(vv); err != nil { + return errors.Wrapf(err, "service initialization error [%s]", addr) + } + } + if vv, ok := asServiceContext(arg); ok { + if err = v.srv.AddAndRun(vv); err != nil { + return errors.Wrapf(err, "service initialization error [%s]", addr) + } + } + delete(names, addr) + if arg.Type().String() == "error" { + continue + } + if err = v.list.Add(arg.Type(), arg.Interface(), typeExist); err != nil { + return errors.Wrapf(err, "initialize error [%s]", addr) + } + } + delete(names, name) + } + + v.srv.IterateOver() + + return nil +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +const ( + typeNew int = iota + typeNewIfNotExist + typeExist +) + +type ( + dicMapItem struct { + Value interface{} + Type int + } + dicMap struct { + data map[string]*dicMapItem + mux sync.RWMutex + } +) + +func newDicMap() *dicMap { + return &dicMap{ + data: make(map[string]*dicMapItem), + } +} + +func (v *dicMap) Add(place, value interface{}, t int) error { + v.mux.Lock() + defer v.mux.Unlock() + + ref, ok := place.(reflect.Type) + if !ok { + ref = reflect.TypeOf(place) + } + + addr, ok := getRefAddr(ref) + if !ok { + if addr != "error" { + return fmt.Errorf("dependency [%s] is not supported", addr) + } + //return nil + } + + if vv, ok := v.data[addr]; ok { + if t == typeNewIfNotExist { + return nil + } + if vv.Type == typeExist { + return fmt.Errorf("dependency [%s] already initiated", addr) + } + } + v.data[addr] = &dicMapItem{ + Value: value, + Type: t, + } + + return nil +} + +func (v *dicMap) Get(addr string) (interface{}, error) { + v.mux.RLock() + defer v.mux.RUnlock() + + if vv, ok := v.data[addr]; ok { + return vv.Value, nil + } + return nil, fmt.Errorf("dependency [%s] not initiated", addr) +} + +func (v *dicMap) HasType(addr string, t int) bool { + v.mux.RLock() + defer v.mux.RUnlock() + + if vv, ok := v.data[addr]; ok { + return vv.Type == t + } + return false +} + +func (v *dicMap) Step(addr string) (int, error) { + v.mux.RLock() + defer v.mux.RUnlock() + + if vv, ok := v.data[addr]; ok { + return vv.Type, nil + } + return 0, fmt.Errorf("dependency [%s] not initiated", addr) +} + +func (v *dicMap) foreach(kFunc, kStruct, kOther func(addr string, ref reflect.Type) error) error { + v.mux.RLock() + defer v.mux.RUnlock() + + for addr, item := range v.data { + if item.Type == typeExist { + continue + } + + ref := reflect.TypeOf(item.Value) + var err error + switch ref.Kind() { + case reflect.Func: + err = kFunc(addr, ref) + case reflect.Struct: + err = kStruct(addr, ref) + default: + err = kOther(addr, ref) + } + + if err != nil { + return err + } + } + return nil +} diff --git a/sdk/app/container_test.go b/sdk/app/container_test.go new file mode 100644 index 0000000..0a2498b --- /dev/null +++ b/sdk/app/container_test.go @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type t0 struct{} + +func newT0() *t0 { return &t0{} } +func (t0 *t0) Up() error { return nil } +func (t0 *t0) Down() error { return nil } +func (t0 *t0) V() string { return "t0V" } + +type t1 struct { + t0 *t0 +} + +func newT1(t0 *t0) *t1 { return &t1{t0: t0} } +func (t1 *t1) Up() error { return nil } +func (t1 *t1) Down() error { return nil } +func (t1 *t1) V() string { return "t1V" } + +type t2 struct { + t0 *t0 + t1 *t1 +} + +func newT2(t1 *t1, t0 *t0) *t2 { return &t2{t0: t0, t1: t1} } +func (t2 *t2) Up() error { return nil } +func (t2 *t2) Down() error { return nil } +func (t2 *t2) V() (string, string, string) { return "t2V", t2.t1.V(), t2.t0.V() } + +type t4 struct { + T0 *t0 + T1 *t1 + T2 *t2 + T7 *t7 + T44 t44 +} + +type t44 struct { + Env string +} + +type t5 struct{} + +func newT5() *t5 { return &t5{} } +func (t5 *t5) V() string { return "t5V" } + +type t6 struct{ T4 t4 } + +func newT6(t4 t4) *t6 { return &t6{T4: t4} } +func (t6 *t6) V() string { return "t6V" } + +type t7 struct{} + +func newT7() *t7 { return &t7{} } +func (t7 *t7) V() string { return "t7V" } + +type t8 struct{} + +func newT8() (*t8, error) { return &t8{}, nil } +func (t8 *t8) V() string { return "t8V" } + +type hello string + +var AA = hello("hhhh") + +type ii interface { + V() string +} + +func newT7i(_ hello) ii { + return &t7{} +} + +func TestUnit_Dependencies(t *testing.T) { + dep := newDic(NewContext()) + + require.NoError(t, dep.Register([]interface{}{ + newT1, newT2, newT5, newT6, newT7(), newT8, + AA, newT7i, newT0, t44{Env: "aaa"}, + func(b *t6) { + t.Log("anonymous function") + }, + }...)) + + require.NoError(t, dep.Build()) + require.Error(t, dep.Build()) + + require.NoError(t, dep.Inject(func( + v1 *t1, v2 *t2, v3 *t5, v4 *t6, v5 *t6, + v6 *t7, v7 *t8, v8 hello, + v9 ii, v10 *t0, v11 t44, + ) { + require.Equal(t, "t1V", v1.V()) + vv2, _, _ := v2.V() + require.Equal(t, "t2V", vv2) + require.Equal(t, "t5V", v3.V()) + require.Equal(t, "t6V", v4.V()) + require.Equal(t, "t6V", v5.V()) + require.Equal(t, "t7V", v6.V()) + require.Equal(t, hello("hhhh"), v8) + require.Equal(t, "t7V", v9.V()) + require.Equal(t, "t0V", v10.V()) + require.Equal(t, "aaa", v11.Env) + })) + + require.Error(t, dep.Inject(func(a string, b int, c bool) { + + })) + + require.NoError(t, dep.Down()) + require.Error(t, dep.Down()) +} + +type demo1 struct{} +type demo2 struct{} +type demo3 struct{} + +func newDemo() (*demo1, *demo2, *demo3) { return &demo1{}, &demo2{}, &demo3{} } +func (d *demo1) Up() error { + fmt.Println("demo1 up") + return nil +} +func (d *demo1) Down() error { + fmt.Println("demo1 down") + return nil +} + +func TestUnit_Dependencies2(t *testing.T) { + dep := newDic(NewContext()) + require.NoError(t, dep.Register([]interface{}{ + newDemo, + }...)) + require.NoError(t, dep.Build()) + require.Error(t, dep.Build()) + require.NoError(t, dep.Down()) + require.Error(t, dep.Down()) +} + +type demo4 struct{} + +func newDemo4() (*demo4, error) { return nil, fmt.Errorf("fail init constructor demo4") } + +func TestUnit_Dependencies3(t *testing.T) { + dep := newDic(NewContext()) + require.NoError(t, dep.Register([]interface{}{ + newDemo4, + }...)) + err := dep.Build() + require.Error(t, err) + fmt.Println(err.Error()) + require.Contains(t, err.Error(), "fail init constructor demo4") +} + +func newDemo5() error { return fmt.Errorf("fail init constructor newDemo5") } + +func TestUnit_Dependencies4(t *testing.T) { + dep := newDic(NewContext()) + require.NoError(t, dep.Register(newDemo5)) + err := dep.Build() + require.Error(t, err) + fmt.Println(err.Error()) + require.Contains(t, err.Error(), "fail init constructor newDemo5") +} + +type demo6 struct{} + +func newDemo6() *demo6 { return &demo6{} } +func (d *demo6) Up() error { + fmt.Println("demo6 up") + return nil +} +func (d *demo6) Down() error { + fmt.Println("demo6 down") + return nil +} +func (d *demo6) Name() string { + return "DEMO 6" +} + +func TestUnit_DicInvoke1(t *testing.T) { + dep := newDic(NewContext()) + require.NoError(t, dep.Register(newDemo6)) + require.NoError(t, dep.Invoke(func(d *demo6) { + fmt.Println("Invoke", d.Name()) + })) + require.NoError(t, dep.Down()) + require.Error(t, dep.Down()) +} + +func TestUnit_DicInvoke2(t *testing.T) { + dep := newDic(NewContext()) + require.NoError(t, dep.Register(newDemo6)) + require.NoError(t, dep.Invoke(func() { + fmt.Println("Invoke") + })) + require.NoError(t, dep.Down()) + require.Error(t, dep.Down()) +} diff --git a/sdk/app/ctx.go b/sdk/app/ctx.go new file mode 100644 index 0000000..31ee4b6 --- /dev/null +++ b/sdk/app/ctx.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import "context" + +type ( + _ctx struct { + ctx context.Context + cancel context.CancelFunc + } + + //Context model for force close application + Context interface { + Close() + Context() context.Context + Done() <-chan struct{} + } +) + +func NewContext() Context { + ctx, cancel := context.WithCancel(context.Background()) + + return &_ctx{ + ctx: ctx, + cancel: cancel, + } +} + +// Close context close method +func (v *_ctx) Close() { + v.cancel() +} + +// Context general context +func (v *_ctx) Context() context.Context { + return v.ctx +} + +// Done context close wait channel +func (v *_ctx) Done() <-chan struct{} { + return v.ctx.Done() +} diff --git a/sdk/app/errors.go b/sdk/app/errors.go new file mode 100644 index 0000000..76d207b --- /dev/null +++ b/sdk/app/errors.go @@ -0,0 +1,15 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import "github.com/osspkg/goppy/sdk/errors" + +var ( + errDepBuilderNotRunning = errors.New("dependencies builder is not running") + errDepNotRunning = errors.New("dependencies are not running yet") + errServiceUnknown = errors.New("unknown service") + errBadFileFormat = errors.New("is not a supported file format") +) diff --git a/sdk/app/logger.go b/sdk/app/logger.go new file mode 100644 index 0000000..4b8214a --- /dev/null +++ b/sdk/app/logger.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "os" + + "github.com/osspkg/goppy/sdk/log" +) + +type _log struct { + file *os.File + handler log.Logger + conf *Config +} + +func newLog(conf *Config) *_log { + file, err := os.OpenFile(conf.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) + if err != nil { + panic(err) + } + return &_log{file: file, conf: conf} +} + +func (v *_log) Handler(l log.Logger) { + v.handler = l + v.handler.SetOutput(v.file) + v.handler.SetLevel(v.conf.Level) +} + +func (v *_log) Close() error { + if v.handler != nil { + v.handler.Close() + } + return v.file.Close() +} diff --git a/sdk/app/modules.go b/sdk/app/modules.go new file mode 100644 index 0000000..b0cdc75 --- /dev/null +++ b/sdk/app/modules.go @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +// Modules DI container +type Modules []interface{} + +// Add object to container +func (m Modules) Add(v ...interface{}) Modules { + for _, mod := range v { + switch v := mod.(type) { + case Modules: + m = m.Add(v...) + default: + m = append(m, mod) + } + } + return m +} diff --git a/sdk/app/modules_test.go b/sdk/app/modules_test.go new file mode 100644 index 0000000..9c0efdd --- /dev/null +++ b/sdk/app/modules_test.go @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app_test + +import ( + "testing" + + application "github.com/osspkg/goppy/sdk/app" + "github.com/stretchr/testify/require" +) + +func TestUnit_Modules(t *testing.T) { + tmp1 := application.Modules{8, 9, "W"} + tmp2 := application.Modules{18, 19, "aW", tmp1} + main := application.Modules{1, 2, "qqq"}.Add(tmp2).Add(99) + + require.Equal(t, application.Modules{1, 2, "qqq", 18, 19, "aW", 8, 9, "W", 99}, main) +} diff --git a/sdk/app/reflect.go b/sdk/app/reflect.go new file mode 100644 index 0000000..732a9fb --- /dev/null +++ b/sdk/app/reflect.go @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "fmt" + "reflect" +) + +var errType = reflect.TypeOf(new(error)).Elem() + +func getRefAddr(t reflect.Type) (string, bool) { + if len(t.PkgPath()) > 0 { + return t.PkgPath() + "." + t.Name(), true + } + switch t.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice: + if t.Implements(errType) { + return "error", false + } + if len(t.Elem().PkgPath()) > 0 { + return t.Elem().PkgPath() + "." + t.Elem().Name(), true + } + case reflect.Func: + // TODO: fix for anonymous function + // random.String(30) + "." + t.String(), true + return t.String(), true + } + return t.String(), false +} + +func typingRefPtr(vv []interface{}, call func(interface{}) error) ([]interface{}, error) { + result := make([]interface{}, 0, len(vv)) + for _, v := range vv { + ref := reflect.TypeOf(v) + switch ref.Kind() { + case reflect.Struct: + in := reflect.New(ref).Interface() + if err := call(in); err != nil { + return nil, err + } + rv := reflect.ValueOf(in).Elem().Interface() + result = append(result, rv) + case reflect.Ptr: + if err := call(v); err != nil { + return nil, err + } + result = append(result, v) + default: + return nil, fmt.Errorf("supported type [%T]", v) + } + } + return result, nil +} diff --git a/sdk/app/reflect_test.go b/sdk/app/reflect_test.go new file mode 100644 index 0000000..48b3812 --- /dev/null +++ b/sdk/app/reflect_test.go @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "reflect" + "strings" + "testing" + + "github.com/osspkg/goppy/sdk/errors" +) + +func TestUnit_getRefAddr(t *testing.T) { + type ( + aa string + bb struct{} + ff func(_ string) bool + ) + var ( + a = 0 + b = "0" + c = false + d = aa("aaa") + e ff = func(_ string) bool { return false } + f = func(_ string) bool { return false } + g = errors.New("") + h = []string{} + j = bb{} + k = struct{}{} + ) + + tests := []struct { + name string + args reflect.Type + want string + ok bool + }{ + {name: "Case1", args: reflect.TypeOf(a), want: "int"}, + {name: "Case2", args: reflect.TypeOf(b), want: "string"}, + {name: "Case3", args: reflect.TypeOf(c), want: "bool"}, + {name: "Case4", args: reflect.TypeOf(d), want: "github.com/osspkg/goppy/sdk/app.aa", ok: true}, + {name: "Case5", args: reflect.TypeOf(e), want: "github.com/osspkg/goppy/sdk/app.ff", ok: true}, + {name: "Case6", args: reflect.TypeOf(f), want: "func(string) bool", ok: true}, + {name: "Case7", args: reflect.TypeOf(g), want: "error"}, + {name: "Case8", args: reflect.TypeOf(h), want: "[]string"}, + {name: "Case9", args: reflect.TypeOf(j), want: "github.com/osspkg/goppy/sdk/app.bb", ok: true}, + {name: "Case10", args: reflect.TypeOf(k), want: "struct {}"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := getRefAddr(tt.args) + if !strings.Contains(got, tt.want) { + t.Errorf("getRefAddr() = %v, want %v", got, tt.want) + } + if ok != tt.ok { + t.Errorf("getRefAddr() = %v, want %v", ok, tt.ok) + } + }) + } +} diff --git a/sdk/app/services.go b/sdk/app/services.go new file mode 100644 index 0000000..053ad69 --- /dev/null +++ b/sdk/app/services.go @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "reflect" + "sync/atomic" + + "github.com/osspkg/goppy/sdk/errors" +) + +type ( + // ServiceInterface interface for services + ServiceInterface interface { + Up() error + Down() error + } + //ServiceContextInterface interface for services with context + ServiceContextInterface interface { + Up(ctx Context) error + Down() error + } +) + +var ( + srvType = reflect.TypeOf(new(ServiceInterface)).Elem() + srvTypeCtx = reflect.TypeOf(new(ServiceContextInterface)).Elem() +) + +func asService(v reflect.Value) (ServiceInterface, bool) { + if v.Type().AssignableTo(srvType) { + return v.Interface().(ServiceInterface), true + } + return nil, false +} + +func asServiceContext(v reflect.Value) (ServiceContextInterface, bool) { + if v.Type().AssignableTo(srvTypeCtx) { + return v.Interface().(ServiceContextInterface), true + } + return nil, false +} + +func isService(v interface{}) bool { + if _, ok := v.(ServiceInterface); ok { + return true + } + if _, ok := v.(ServiceContextInterface); ok { + return true + } + return false +} + +/**********************************************************************************************************************/ + +const ( + statusUp uint32 = 1 + statusDown uint32 = 0 +) + +type ( + _serv struct { + tree *treeItem + status uint32 + ctx Context + } + treeItem struct { + Previous *treeItem + Current interface{} + Next *treeItem + } +) + +func newService(ctx Context) *_serv { + return &_serv{ + tree: nil, + status: statusDown, + ctx: ctx, + } +} + +// IsUp - mark that all services have started +func (s *_serv) IsUp() bool { + return atomic.LoadUint32(&s.status) == statusUp +} + +// AddAndRun - add new service by interface +func (s *_serv) AddAndRun(v interface{}) error { + if !s.IsUp() { + return errDepBuilderNotRunning + } + + if !isService(v) { + return errors.Wrapf(errServiceUnknown, "service [%T]", v) + } + + if s.tree == nil { + s.tree = &treeItem{ + Previous: nil, + Current: v, + Next: nil, + } + } else { + n := &treeItem{ + Previous: s.tree, + Current: v, + Next: nil, + } + n.Previous.Next = n + s.tree = n + } + + if vv, ok := v.(ServiceContextInterface); ok { + if err := vv.Up(s.ctx); err != nil { + return err + } + } + if vv, ok := v.(ServiceInterface); ok { + if err := vv.Up(); err != nil { + return err + } + } + + return nil +} + +func (s *_serv) MakeAsUp() error { + if !atomic.CompareAndSwapUint32(&s.status, statusDown, statusUp) { + return errDepBuilderNotRunning + } + return nil +} + +func (s *_serv) IterateOver() { + if s.tree == nil { + return + } + for s.tree.Previous != nil { + s.tree = s.tree.Previous + } + for { + if s.tree.Next == nil { + break + } + s.tree = s.tree.Next + } + return +} + +// Down - stop all services +func (s *_serv) Down() error { + var err0 error + if !atomic.CompareAndSwapUint32(&s.status, statusUp, statusDown) { + return errDepNotRunning + } + if s.tree == nil { + return nil + } + for { + if vv, ok := s.tree.Current.(ServiceContextInterface); ok { + if err := vv.Down(); err != nil { + err0 = errors.Wrap(err0, + errors.Wrapf(err, "down [%T] service error", s.tree.Current), + ) + } + } else if vv, ok := s.tree.Current.(ServiceInterface); ok { + if err := vv.Down(); err != nil { + err0 = errors.Wrap(err0, + errors.Wrapf(err, "down [%T] service error", s.tree.Current), + ) + } + } else { + return errors.Wrapf(errServiceUnknown, "service [%T]", s.tree.Current) + } + if s.tree.Previous == nil { + break + } + s.tree = s.tree.Previous + } + for s.tree.Next != nil { + s.tree = s.tree.Next + } + return err0 +} diff --git a/sdk/app/sources.go b/sdk/app/sources.go new file mode 100644 index 0000000..65f8ff9 --- /dev/null +++ b/sdk/app/sources.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package app + +import ( + "encoding/json" + "os" + "path/filepath" + + "github.com/osspkg/goppy/sdk/errors" + "gopkg.in/yaml.v3" +) + +// Sources model +type Sources string + +// Decode unmarshal file to model +func (v Sources) Decode(configs ...interface{}) error { + data, err := os.ReadFile(string(v)) + if err != nil { + return err + } + ext := filepath.Ext(string(v)) + switch ext { + case ".yml", ".yaml": + return v.unmarshal("yaml unmarshal", data, yaml.Unmarshal, configs...) + case ".json": + return v.unmarshal("json unmarshal", data, json.Unmarshal, configs...) + } + return errBadFileFormat +} + +func (v Sources) unmarshal( + title string, data []byte, call func([]byte, interface{}, + ) error, configs ...interface{}) error { + for _, conf := range configs { + if err := call(data, conf); err != nil { + return errors.Wrapf(err, title) + } + } + return nil +} diff --git a/sdk/auth/jwt/jwt.go b/sdk/auth/jwt/jwt.go new file mode 100644 index 0000000..75acd4b --- /dev/null +++ b/sdk/auth/jwt/jwt.go @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package jwt + +//go:generate easyjson + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/sha512" + "encoding/base64" + "encoding/json" + "fmt" + "hash" + "strings" + "time" + + "github.com/osspkg/goppy/sdk/encryption/aesgcm" +) + +const ( + AlgHS256 = "HS256" + AlgHS384 = "HS384" + AlgHS512 = "HS512" +) + +type Config struct { + ID string `yaml:"id"` + Key string `yaml:"key"` + Algorithm string `yaml:"alg"` +} + +//easyjson:json +type Header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"eat"` +} + +type ( + JWT struct { + pool map[string]*Pool + } + + Pool struct { + conf Config + hash func() hash.Hash + key []byte + codec *aesgcm.Codec + } +) + +func New(conf []Config) (*JWT, error) { + obj := &JWT{pool: make(map[string]*Pool)} + + for _, c := range conf { + var h func() hash.Hash + switch c.Algorithm { + case AlgHS256: + h = sha256.New + case AlgHS384: + h = sha512.New384 + case AlgHS512: + h = sha512.New + default: + return nil, fmt.Errorf("jwt algorithm not supported") + } + codec, err := aesgcm.New(c.Key) + if err != nil { + return nil, fmt.Errorf("jwt init codec: %w", err) + } + obj.pool[c.ID] = &Pool{conf: c, hash: h, key: []byte(c.Key), codec: codec} + } + + return obj, nil +} + +func (v *JWT) rndPool() (*Pool, error) { + for _, p := range v.pool { + return p, nil + } + return nil, fmt.Errorf("jwt pool is empty") +} + +func (v *JWT) getPool(id string) (*Pool, error) { + p, ok := v.pool[id] + if ok { + return p, nil + } + return nil, fmt.Errorf("jwt pool not found") +} + +func (v *JWT) calcHash(hash func() hash.Hash, key []byte, data []byte) ([]byte, error) { + mac := hmac.New(hash, key) + if _, err := mac.Write(data); err != nil { + return nil, err + } + result := mac.Sum(nil) + return result, nil +} + +func (v *JWT) Sign(payload interface{}, ttl time.Duration) (string, error) { + pool, err := v.rndPool() + if err != nil { + return "", err + } + + rh := &Header{ + Kid: pool.conf.ID, + Alg: pool.conf.Algorithm, + IssuedAt: time.Now().Unix(), + ExpiresAt: time.Now().Add(ttl).Unix(), + } + h, err := json.Marshal(rh) + if err != nil { + return "", err + } + result := base64.StdEncoding.EncodeToString(h) + + p, err := json.Marshal(payload) + if err != nil { + return "", err + } + p, err = pool.codec.Encrypt(p) + if err != nil { + return "", err + } + result += "." + base64.StdEncoding.EncodeToString(p) + + s, err := v.calcHash(pool.hash, pool.key, []byte(result)) + if err != nil { + return "", err + } + result += "." + base64.StdEncoding.EncodeToString(s) + + return result, nil +} + +func (v *JWT) Verify(token string, payload interface{}) (*Header, error) { + data := strings.Split(token, ".") + if len(data) != 3 { + return nil, fmt.Errorf("invalid jwt format") + } + + h, err := base64.StdEncoding.DecodeString(data[0]) + if err != nil { + return nil, err + } + header := &Header{} + if err = json.Unmarshal(h, header); err != nil { + return nil, err + } + + pool, err := v.getPool(header.Kid) + if err != nil { + return nil, err + } + + if header.Alg != pool.conf.Algorithm { + return nil, fmt.Errorf("invalid jwt algorithm") + } + if header.ExpiresAt < time.Now().Unix() { + return nil, fmt.Errorf("jwt expired") + } + + expected, err := base64.StdEncoding.DecodeString(data[2]) + if err != nil { + return nil, err + } + actual, err := v.calcHash(pool.hash, pool.key, []byte(data[0]+"."+data[1])) + if err != nil { + return nil, err + } + if !hmac.Equal(expected, actual) { + return nil, fmt.Errorf("invalid jwt signature") + } + + p, err := base64.StdEncoding.DecodeString(data[1]) + if err != nil { + return nil, err + } + p, err = pool.codec.Decrypt(p) + if err != nil { + return nil, err + } + if err = json.Unmarshal(p, payload); err != nil { + return nil, err + } + + return header, nil +} diff --git a/sdk/auth/jwt/jwt_easyjson.go b/sdk/auth/jwt/jwt_easyjson.go new file mode 100644 index 0000000..05e5981 --- /dev/null +++ b/sdk/auth/jwt/jwt_easyjson.go @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package jwt + +import ( + json "encoding/json" + + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson171edd05DecodeGithubComOsspkgGoSdkAuthJwt(in *jlexer.Lexer, out *Header) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "kid": + out.Kid = string(in.String()) + case "alg": + out.Alg = string(in.String()) + case "iat": + out.IssuedAt = int64(in.Int64()) + case "eat": + out.ExpiresAt = int64(in.Int64()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson171edd05EncodeGithubComOsspkgGoSdkAuthJwt(out *jwriter.Writer, in Header) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"kid\":" + out.RawString(prefix[1:]) + out.String(string(in.Kid)) + } + { + const prefix string = ",\"alg\":" + out.RawString(prefix) + out.String(string(in.Alg)) + } + { + const prefix string = ",\"iat\":" + out.RawString(prefix) + out.Int64(int64(in.IssuedAt)) + } + { + const prefix string = ",\"eat\":" + out.RawString(prefix) + out.Int64(int64(in.ExpiresAt)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v Header) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson171edd05EncodeGithubComOsspkgGoSdkAuthJwt(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v Header) MarshalEasyJSON(w *jwriter.Writer) { + easyjson171edd05EncodeGithubComOsspkgGoSdkAuthJwt(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *Header) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson171edd05DecodeGithubComOsspkgGoSdkAuthJwt(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *Header) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson171edd05DecodeGithubComOsspkgGoSdkAuthJwt(l, v) +} diff --git a/sdk/auth/jwt/jwt_test.go b/sdk/auth/jwt/jwt_test.go new file mode 100644 index 0000000..03f1663 --- /dev/null +++ b/sdk/auth/jwt/jwt_test.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package jwt_test + +import ( + "testing" + "time" + + "github.com/osspkg/goppy/sdk/auth/jwt" + "github.com/stretchr/testify/require" +) + +type demoJwtPayload struct { + ID int `json:"id"` +} + +func TestUnit_NewJWT(t *testing.T) { + conf := make([]jwt.Config, 0) + conf = append(conf, jwt.Config{ID: "789456", Key: "123456789123456789123456789123456789", Algorithm: jwt.AlgHS256}) + j, err := jwt.New(conf) + require.NoError(t, err) + + payload1 := demoJwtPayload{ID: 159} + token, err := j.Sign(&payload1, time.Hour) + require.NoError(t, err) + + payload2 := demoJwtPayload{} + head1, err := j.Verify(token, &payload2) + require.NoError(t, err) + + require.Equal(t, payload1, payload2) + + head2, err := j.Verify(token, &payload2) + require.NoError(t, err) + require.Equal(t, head1, head2) +} diff --git a/sdk/auth/oauth/isp.go b/sdk/auth/oauth/isp.go new file mode 100644 index 0000000..64fc72f --- /dev/null +++ b/sdk/auth/oauth/isp.go @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package oauth + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/ioutil" + "golang.org/x/oauth2" +) + +var ( + errProviderFail = errors.New("provider not found") +) + +type ( + User interface { + GetName() string + GetEmail() string + GetIcon() string + } + + Provider interface { + Code() string + Config(conf ConfigItem) + AuthCodeURL() string + AuthCodeKey() string + Exchange(ctx context.Context, code string) (User, error) + } +) + +func (v *OAuth) AddProviders(p ...Provider) { + v.mux.Lock() + defer v.mux.Unlock() + + for _, item := range p { + for _, cp := range v.config.Provider { + if cp.Code == item.Code() { + item.Config(cp) + v.list[item.Code()] = item + } + } + } +} + +func (v *OAuth) GetProvider(name string) (Provider, error) { + v.mux.RLock() + defer v.mux.RUnlock() + + p, ok := v.list[name] + if !ok { + return nil, errProviderFail + } + return p, nil +} + +/**********************************************************************************************************************/ + +type oauth2Config interface { + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + Client(ctx context.Context, t *oauth2.Token) *http.Client +} + +func oauth2ExchangeContext( + ctx context.Context, code string, uri string, srv oauth2Config, model json.Unmarshaler, +) error { + tok, err := srv.Exchange(ctx, code) + if err != nil { + return errors.Wrapf(err, "exchange to oauth service") + } + client := srv.Client(ctx, tok) + resp, err := client.Get(uri) //nolint: bodyclose + if err != nil { + return errors.Wrapf(err, "client request to oauth service") + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read response from oauth service") + } + if err = json.Unmarshal(b, model); err != nil { + return errors.Wrapf(err, "decode oauth model") + } + return nil +} diff --git a/sdk/auth/oauth/isp_google.go b/sdk/auth/oauth/isp_google.go new file mode 100644 index 0000000..a8c4fc7 --- /dev/null +++ b/sdk/auth/oauth/isp_google.go @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package oauth + +//go:generate easyjson + +import ( + "context" + "encoding/json" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" +) + +const CodeGoogle = "google" + +type ( + //easyjson:json + modelGoogle struct { + Name string `json:"name"` + Icon string `json:"picture"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + } + + UserGoogle struct { + name string + icon string + email string + } +) + +func (v *UserGoogle) UnmarshalJSON(data []byte) error { + var tmp modelGoogle + if err := json.Unmarshal(data, &tmp); err != nil { + return err + } + + if tmp.EmailVerified { + v.name = tmp.Name + v.icon = tmp.Icon + v.email = tmp.Email + } + + return nil +} + +func (v *UserGoogle) GetName() string { + return v.name +} + +func (v *UserGoogle) GetIcon() string { + return v.icon +} + +func (v *UserGoogle) GetEmail() string { + return v.email +} + +/**********************************************************************************************************************/ + +type IspGoogle struct { + oauth *oauth2.Config + config configIsp +} + +func (v *IspGoogle) Code() string { + return CodeGoogle +} + +func (v *IspGoogle) Config(c ConfigItem) { + v.oauth = &oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + RedirectURL: c.RedirectURL, + Endpoint: google.Endpoint, + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + }, + } + v.config = configIsp{ + State: "state", + AuthCodeKey: "code", + RequestURL: "https://openidconnect.googleapis.com/v1/userinfo", + } +} + +func (v *IspGoogle) AuthCodeURL() string { + return v.oauth.AuthCodeURL(v.config.State) +} + +func (v *IspGoogle) AuthCodeKey() string { + return v.config.AuthCodeKey +} + +func (v *IspGoogle) Exchange(ctx context.Context, code string) (User, error) { + m := &UserGoogle{} + if err := oauth2ExchangeContext(ctx, code, v.config.RequestURL, v.oauth, m); err != nil { + return nil, err + } + return m, nil +} diff --git a/sdk/auth/oauth/isp_google_easyjson.go b/sdk/auth/oauth/isp_google_easyjson.go new file mode 100644 index 0000000..5e451d3 --- /dev/null +++ b/sdk/auth/oauth/isp_google_easyjson.go @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package oauth + +import ( + json "encoding/json" + + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson3bc980faDecodeGithubComOsspkgGoSdkAuthOauth(in *jlexer.Lexer, out *modelGoogle) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "name": + out.Name = string(in.String()) + case "picture": + out.Icon = string(in.String()) + case "email": + out.Email = string(in.String()) + case "email_verified": + out.EmailVerified = bool(in.Bool()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson3bc980faEncodeGithubComOsspkgGoSdkAuthOauth(out *jwriter.Writer, in modelGoogle) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"name\":" + out.RawString(prefix[1:]) + out.String(string(in.Name)) + } + { + const prefix string = ",\"picture\":" + out.RawString(prefix) + out.String(string(in.Icon)) + } + { + const prefix string = ",\"email\":" + out.RawString(prefix) + out.String(string(in.Email)) + } + { + const prefix string = ",\"email_verified\":" + out.RawString(prefix) + out.Bool(bool(in.EmailVerified)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v modelGoogle) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson3bc980faEncodeGithubComOsspkgGoSdkAuthOauth(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v modelGoogle) MarshalEasyJSON(w *jwriter.Writer) { + easyjson3bc980faEncodeGithubComOsspkgGoSdkAuthOauth(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *modelGoogle) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson3bc980faDecodeGithubComOsspkgGoSdkAuthOauth(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *modelGoogle) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson3bc980faDecodeGithubComOsspkgGoSdkAuthOauth(l, v) +} diff --git a/sdk/auth/oauth/isp_yandex.go b/sdk/auth/oauth/isp_yandex.go new file mode 100644 index 0000000..1be7b52 --- /dev/null +++ b/sdk/auth/oauth/isp_yandex.go @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package oauth + +//go:generate easyjson + +import ( + "context" + "encoding/json" + "fmt" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/yandex" +) + +const CodeYandex = "yandex" + +type ( + //easyjson:json + modelYandex struct { + Name string `json:"display_name"` + Icon string `json:"default_avatar_id"` + Email string `json:"default_email"` + } + + UserYandex struct { + name string + icon string + email string + } +) + +func (v *UserYandex) UnmarshalJSON(data []byte) error { + var tmp modelYandex + if err := json.Unmarshal(data, &tmp); err != nil { + return err + } + + if len(tmp.Icon) > 0 { + v.icon = fmt.Sprintf("https://avatars.yandex.net/get-yapic/%s/islands-retina-50", tmp.Icon) + } + v.name = tmp.Name + v.email = tmp.Email + + return nil +} + +func (v *UserYandex) GetName() string { + return v.name +} + +func (v *UserYandex) GetIcon() string { + return v.icon +} + +func (v *UserYandex) GetEmail() string { + return v.email +} + +/**********************************************************************************************************************/ + +type IspYandex struct { + oauth *oauth2.Config + config configIsp +} + +func (v *IspYandex) Code() string { + return CodeYandex +} + +func (v *IspYandex) Config(c ConfigItem) { + v.oauth = &oauth2.Config{ + ClientID: c.ClientID, + ClientSecret: c.ClientSecret, + RedirectURL: c.RedirectURL, + Endpoint: yandex.Endpoint, + Scopes: []string{ + "login:email", + "login:info", + "login:avatar", + }, + } + v.config = configIsp{ + State: "state", + AuthCodeKey: "code", + RequestURL: "https://login.yandex.ru/info", + } +} + +func (v *IspYandex) AuthCodeURL() string { + return v.oauth.AuthCodeURL(v.config.State) +} + +func (v *IspYandex) AuthCodeKey() string { + return v.config.AuthCodeKey +} + +func (v *IspYandex) Exchange(ctx context.Context, code string) (User, error) { + m := &UserYandex{} + if err := oauth2ExchangeContext(ctx, code, v.config.RequestURL, v.oauth, m); err != nil { + return nil, err + } + return m, nil +} diff --git a/sdk/auth/oauth/isp_yandex_easyjson.go b/sdk/auth/oauth/isp_yandex_easyjson.go new file mode 100644 index 0000000..749b7cb --- /dev/null +++ b/sdk/auth/oauth/isp_yandex_easyjson.go @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package oauth + +import ( + json "encoding/json" + + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjsonD1fc6ea8DecodeGithubComOsspkgGoSdkAuthOauth(in *jlexer.Lexer, out *modelYandex) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "display_name": + out.Name = string(in.String()) + case "default_avatar_id": + out.Icon = string(in.String()) + case "default_email": + out.Email = string(in.String()) + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjsonD1fc6ea8EncodeGithubComOsspkgGoSdkAuthOauth(out *jwriter.Writer, in modelYandex) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"display_name\":" + out.RawString(prefix[1:]) + out.String(string(in.Name)) + } + { + const prefix string = ",\"default_avatar_id\":" + out.RawString(prefix) + out.String(string(in.Icon)) + } + { + const prefix string = ",\"default_email\":" + out.RawString(prefix) + out.String(string(in.Email)) + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v modelYandex) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjsonD1fc6ea8EncodeGithubComOsspkgGoSdkAuthOauth(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v modelYandex) MarshalEasyJSON(w *jwriter.Writer) { + easyjsonD1fc6ea8EncodeGithubComOsspkgGoSdkAuthOauth(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *modelYandex) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjsonD1fc6ea8DecodeGithubComOsspkgGoSdkAuthOauth(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *modelYandex) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjsonD1fc6ea8DecodeGithubComOsspkgGoSdkAuthOauth(l, v) +} diff --git a/sdk/auth/oauth/oauth.go b/sdk/auth/oauth/oauth.go new file mode 100644 index 0000000..bc3a0ac --- /dev/null +++ b/sdk/auth/oauth/oauth.go @@ -0,0 +1,95 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package oauth + +import ( + "net/http" + "sync" +) + +/**********************************************************************************************************************/ + +type ( + ConfigItem struct { + Code string `yaml:"code"` + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + RedirectURL string `yaml:"redirect_url"` + } + + Config struct { + Provider []ConfigItem `yaml:"oauth"` + } + + configIsp struct { + State string + AuthCodeKey string + RequestURL string + } +) + +/**********************************************************************************************************************/ + +type ( + OAuth struct { + config *Config + list map[string]Provider + mux sync.RWMutex + } + + CallBack func(http.ResponseWriter, *http.Request, User) +) + +func New(c *Config) *OAuth { + return &OAuth{ + config: c, + list: make(map[string]Provider), + } +} + +func (v *OAuth) Up() error { + v.AddProviders( + &IspYandex{}, + ) + return nil +} + +func (v *OAuth) Down() error { + return nil +} + +func (v *OAuth) Request(name string) func(http.ResponseWriter, *http.Request) { + p, err := v.GetProvider(name) + if err != nil { + return func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) //nolint: errcheck + } + } + return func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, p.AuthCodeURL(), http.StatusMovedPermanently) + } +} + +func (v *OAuth) CallBack(name string, call CallBack) func(w http.ResponseWriter, r *http.Request) { + p, err := v.GetProvider(name) + if err != nil { + return func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) //nolint: errcheck + } + } + return func(w http.ResponseWriter, r *http.Request) { + code := r.URL.Query().Get(p.AuthCodeKey()) + u, err := p.Exchange(r.Context(), code) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(err.Error())) //nolint: errcheck + return + } + call(w, r, u) + } +} diff --git a/sdk/certificate/pgp/pgp.go b/sdk/certificate/pgp/pgp.go new file mode 100644 index 0000000..e837627 --- /dev/null +++ b/sdk/certificate/pgp/pgp.go @@ -0,0 +1,265 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package pgp + +import ( + "bytes" + "crypto" + "io" + "os" + + "github.com/osspkg/goppy/sdk/errors" + "golang.org/x/crypto/openpgp" + "golang.org/x/crypto/openpgp/armor" + "golang.org/x/crypto/openpgp/clearsign" + "golang.org/x/crypto/openpgp/packet" +) + +type ( + Config struct { + Name, Email, Comment string + } + + Cert struct { + Public []byte + Private []byte + } +) + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type ( + store struct { + key *openpgp.Entity + conf *packet.Config + headers map[string]string + } + + Signer interface { + SetKey(b []byte, passwd string) error + SetKeyFromFile(filename string, passwd string) error + SetHash(hash crypto.Hash, bits int) + PublicKey() ([]byte, error) + PublicKeyBase64() ([]byte, error) + Sign(in io.Reader, out io.Writer) error + } +) + +func New() Signer { + return &store{ + conf: &packet.Config{ + DefaultHash: crypto.SHA512, + RSABits: 4096, + }, + headers: make(map[string]string), + } +} + +func (v *store) SetKey(b []byte, passwd string) error { + r := bytes.NewReader(b) + return v.readKey(r, passwd) +} + +func (v *store) SetHash(hash crypto.Hash, bits int) { + v.conf.DefaultHash = hash + v.conf.RSABits = bits +} + +func (v *store) SetHeaders(headers ...string) error { + h, err := createHeaders(headers) + if err != nil { + return err + } + v.headers = mergeHeaders(v.headers, h) + return nil +} + +func (v *store) SetKeyFromFile(filename string, passwd string) error { + r, err := os.Open(filename) + if err != nil { + return errors.Wrapf(err, "read key from file") + } + return v.readKey(r, passwd) +} + +func (v *store) PublicKey() ([]byte, error) { + if v.key == nil { + return nil, errors.New("key is empty") + } + + var buf bytes.Buffer + if err := v.key.Serialize(&buf); err != nil { + return nil, errors.Wrapf(err, "serialize public key") + } + return buf.Bytes(), nil +} + +func (v *store) PublicKeyBase64() ([]byte, error) { + if v.key == nil { + return nil, errors.New("key is empty") + } + + var buf bytes.Buffer + enc, err := armor.Encode(&buf, openpgp.PublicKeyType, v.headers) + if err != nil { + return nil, errors.Wrapf(err, "init armor encoder") + } + if err = v.key.Serialize(enc); err != nil { + return nil, errors.Wrapf(err, "serialize public key") + } + if err = enc.Close(); err != nil { + return nil, errors.Wrapf(err, "close armor encoder") + } + return buf.Bytes(), nil +} + +func (v *store) readKey(r io.ReadSeeker, passwd string) error { + block, err := armor.Decode(r) + if err != nil { + return errors.Wrapf(err, "armor decode key") + } + if block.Type != openpgp.PrivateKeyType { + return errors.Wrapf(err, "invalid key type") + } + if _, err = r.Seek(0, 0); err != nil { + return errors.Wrapf(err, "seek key file") + } + keys, err := openpgp.ReadArmoredKeyRing(r) + if err != nil { + return errors.Wrapf(err, "read armored key") + } + v.key = keys[0] + if v.key.PrivateKey.Encrypted { + if err = v.key.PrivateKey.Decrypt([]byte(passwd)); err != nil { + return errors.Wrapf(err, "invalid password") + } + } + v.headers = mergeHeaders(v.headers, block.Header) + return nil +} + +func (v *store) Sign(in io.Reader, out io.Writer) error { + if v.key == nil { + return errors.New("key is empty") + } + + w, err := clearsign.Encode(out, v.key.PrivateKey, v.conf) + if err != nil { + return errors.Wrapf(err, "init") + } + if _, err = io.Copy(w, in); err != nil { + return err + } + if err = w.Close(); err != nil { + return err + } + return nil +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +func generatePrivateKey(key *openpgp.Entity, w io.Writer, headers map[string]string) error { + enc, err := armor.Encode(w, openpgp.PrivateKeyType, headers) + if err != nil { + return errors.Wrapf(err, "init armor encoder") + } + defer enc.Close() //nolint: errcheck + + if err = key.SerializePrivate(enc, nil); err != nil { + return errors.Wrapf(err, "serialize private key") + } + + return nil +} + +func generatePublicKey(key *openpgp.Entity, w io.Writer, headers map[string]string) error { + enc, err := armor.Encode(w, openpgp.PublicKeyType, headers) + if err != nil { + return errors.Wrapf(err, "create OpenPGP armor") + } + defer enc.Close() //nolint: errcheck + + if err = key.Serialize(enc); err != nil { + return errors.Wrapf(err, "serialize public key") + } + + return nil +} + +func setupIdentities(key *openpgp.Entity, c *packet.Config) error { + // Sign all the identities + for _, id := range key.Identities { + id.SelfSignature.PreferredCompression = []uint8{1, 2, 3, 0} + id.SelfSignature.PreferredHash = []uint8{2, 8, 10, 1, 3, 9, 11} + id.SelfSignature.PreferredSymmetric = []uint8{9, 8, 7, 3, 2} + + if err := id.SelfSignature.SignUserId(id.UserId.Id, key.PrimaryKey, key.PrivateKey, c); err != nil { + return err + } + } + return nil +} + +func createHeaders(v []string) (map[string]string, error) { + if len(v)%2 != 0 { + return nil, errors.New("odd headers count") + } + result := make(map[string]string, len(v)/2) + for i := 0; i < len(v); i += 2 { + result[v[i]] = v[i+1] + } + return result, nil +} + +func mergeHeaders(h ...map[string]string) map[string]string { + result := make(map[string]string) + for _, m := range h { + for k, v := range m { + result[k] = v + } + } + return result +} + +func NewCert(c Config, hash crypto.Hash, bits int, headers ...string) (*Cert, error) { + h, err := createHeaders(headers) + if err != nil { + return nil, errors.Wrapf(err, "parse headers") + } + + conf := &packet.Config{ + DefaultHash: hash, + RSABits: bits, + } + + key, err := openpgp.NewEntity(c.Name, c.Comment, c.Email, conf) + if err != nil { + return nil, errors.Wrapf(err, "generate entity") + } + + if err = setupIdentities(key, conf); err != nil { + return nil, errors.Wrapf(err, "setup entity") + } + + var priv bytes.Buffer + if err = generatePrivateKey(key, &priv, h); err != nil { + return nil, errors.Wrapf(err, "generate private key") + } + + var pub bytes.Buffer + if err = generatePublicKey(key, &pub, h); err != nil { + return nil, errors.Wrapf(err, "generate public key") + } + + return &Cert{ + Public: pub.Bytes(), + Private: priv.Bytes(), + }, nil +} + +func NewCertSHA512(c Config, headers ...string) (*Cert, error) { + return NewCert(c, crypto.SHA512, 4096, headers...) +} diff --git a/sdk/certificate/pgp/pgp_test.go b/sdk/certificate/pgp/pgp_test.go new file mode 100644 index 0000000..c9b9310 --- /dev/null +++ b/sdk/certificate/pgp/pgp_test.go @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package pgp_test + +import ( + "bytes" + "crypto" + "testing" + + "github.com/osspkg/goppy/sdk/certificate/pgp" +) + +func TestUnit_PGP(t *testing.T) { + conf := pgp.Config{ + Name: "Test Name", + Email: "Test Email", + Comment: "Test Comment", + } + crt, err := pgp.NewCert(conf, crypto.MD5, 1024, "tool", "dewep utils") + if err != nil { + t.Fatalf(err.Error()) + } + t.Log(string(crt.Private), string(crt.Public)) + + in := bytes.NewBufferString("Hello world") + out := &bytes.Buffer{} + + sig := pgp.New() + if err = sig.SetKey(crt.Private, ""); err != nil { + t.Fatalf(err.Error()) + } + sig.SetHash(crypto.MD5, 1024) + if err = sig.Sign(in, out); err != nil { + t.Fatalf(err.Error()) + } + t.Log(out.String()) +} diff --git a/sdk/certificate/x509/x509.go b/sdk/certificate/x509/x509.go new file mode 100644 index 0000000..57ca70e --- /dev/null +++ b/sdk/certificate/x509/x509.go @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package x509 + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + cx509 "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" + + "github.com/osspkg/goppy/sdk/errors" +) + +type ( + Cert struct { + Public []byte + Private []byte + } + + Config struct { + Organization string + OrganizationalUnit string + Country string + Province string + Locality string + StreetAddress string + PostalCode string + } +) + +func (v *Config) ToSubject() pkix.Name { + result := pkix.Name{} + + if len(v.Country) > 0 { + result.Country = []string{v.Country} + } + if len(v.Organization) > 0 { + result.Organization = []string{v.Organization} + } + if len(v.OrganizationalUnit) > 0 { + result.OrganizationalUnit = []string{v.OrganizationalUnit} + } + if len(v.Locality) > 0 { + result.Locality = []string{v.Locality} + } + if len(v.Province) > 0 { + result.Province = []string{v.Province} + } + if len(v.StreetAddress) > 0 { + result.StreetAddress = []string{v.StreetAddress} + } + if len(v.PostalCode) > 0 { + result.PostalCode = []string{v.PostalCode} + } + + return result +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +func generate(c *Config, ttl time.Duration, sn int64, ca *Cert, cn ...string) (*Cert, error) { + crt := &cx509.Certificate{ + SerialNumber: big.NewInt(sn), + Subject: c.ToSubject(), + NotBefore: time.Now(), + NotAfter: time.Now().Add(ttl), + ExtKeyUsage: []cx509.ExtKeyUsage{cx509.ExtKeyUsageClientAuth, cx509.ExtKeyUsageServerAuth}, + } + + var ( + bits int + b []byte + ) + + if ca == nil { + bits = 4096 + crt.IsCA = true + crt.BasicConstraintsValid = true + crt.KeyUsage = cx509.KeyUsageDigitalSignature | cx509.KeyUsageCertSign + if len(cn) > 0 { + crt.Subject.CommonName = cn[0] + } + } else { + bits = 2048 + crt.KeyUsage = cx509.KeyUsageDigitalSignature + crt.PermittedDNSDomainsCritical = true + for i, s := range cn { + if i == 0 { + crt.Subject.CommonName = cn[0] + } + crt.DNSNames = append(crt.DNSNames, s) + } + } + + pk, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, errors.Wrapf(err, "generate private key") + } + + if ca == nil { + b, err = cx509.CreateCertificate(rand.Reader, crt, crt, &pk.PublicKey, pk) + } else { + block, _ := pem.Decode(ca.Public) + if block == nil { + return nil, errors.New("invalid decode public CA pem ") + } + var caCrt *cx509.Certificate + caCrt, err = cx509.ParseCertificate(block.Bytes) + if err != nil { + return nil, errors.Wrapf(err, "parse CA certificate") + } + + block, _ = pem.Decode(ca.Private) + if block == nil { + return nil, errors.New("invalid decode private CA pem ") + } + var caPK *rsa.PrivateKey + caPK, err = cx509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.Wrapf(err, "decode CA private key") + } + + b, err = cx509.CreateCertificate(rand.Reader, crt, caCrt, &pk.PublicKey, caPK) + } + if err != nil { + return nil, errors.Wrapf(err, "generate certificate") + } + + var pubPEM bytes.Buffer + if err = pem.Encode(&pubPEM, &pem.Block{Type: "CERTIFICATE", Bytes: b}); err != nil { + return nil, errors.Wrapf(err, "encode public pem") + } + + var privPEM bytes.Buffer + if err = pem.Encode(&privPEM, + &pem.Block{Type: "RSA PRIVATE KEY", Bytes: cx509.MarshalPKCS1PrivateKey(pk)}); err != nil { + return nil, errors.Wrapf(err, "encode private pem") + } + + return &Cert{ + Public: pubPEM.Bytes(), + Private: privPEM.Bytes(), + }, nil +} + +func NewCertCA(c *Config, ttl time.Duration, cn string) (*Cert, error) { + return generate(c, ttl, 1, nil, cn) +} + +func NewCert(c *Config, ttl time.Duration, sn int64, ca *Cert, cn ...string) (*Cert, error) { + return generate(c, ttl, sn, ca, cn...) +} diff --git a/sdk/certificate/x509/x509_test.go b/sdk/certificate/x509/x509_test.go new file mode 100644 index 0000000..f228344 --- /dev/null +++ b/sdk/certificate/x509/x509_test.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package x509_test + +import ( + "testing" + "time" + + "github.com/osspkg/goppy/sdk/certificate/x509" +) + +func TestUnit_X509(t *testing.T) { + conf := &x509.Config{ + Organization: "Demo Inc.", + } + + crt, err := x509.NewCertCA(conf, time.Hour*24*365*10, "Demo Root R1") + if err != nil { + t.Fatalf(err.Error()) + } + t.Log(string(crt.Private), string(crt.Public)) + + crt, err = x509.NewCert(conf, time.Hour*24*90, 2, crt, "example.com", "*.example.com") + if err != nil { + t.Fatalf(err.Error()) + } + t.Log(string(crt.Private), string(crt.Public)) +} diff --git a/sdk/console/README.md b/sdk/console/README.md new file mode 100644 index 0000000..85ddd2b --- /dev/null +++ b/sdk/console/README.md @@ -0,0 +1,167 @@ +# Console application + +## Сreating console application + +```go +import "github.com/osspkg/goppy/sdk/console" + +// creating an instance of the application, +// specifying its name and description for flag: --help +root := console.New("tool", "help tool") +// adding root command +root.RootCommand(...) +// adding one or more commands +root.AddCommand(...) +// launching the app +root.Exec() +``` + +## Creating a simple command + +```go +import "github.com/osspkg/goppy/sdk/console" +// creating a new team with settings +console.NewCommand(func(setter console.CommandSetter) { + // passing the command name and description + setter.Setup("simple", "first-level command") + // description of the usage example + setter.Example("simple aa/bb/cc -a=hello -b=123 --cc=123.456 -e") + // description of flags + setter.Flag(func(f console.FlagsSetter) { + // you can specify the flag's name, default value, and information about the flag's value. + f.StringVar("a", "demo", "this is a string argument") + f.IntVar("b", 1, "this is a int64 argument") + f.FloatVar("cc", 1e-5, "this is a float64 argument") + f.Bool("d", "this is a bool argument") + }) + // argument validation: specifies the number of arguments, + // and validation function that should return + // value after validation and validation error + setter.ArgumentFunc(func(s []string) ([]string, error) { + if !strings.Contains(s[0], "/") { + return nil, fmt.Errorf("argument must contain /") + } + return strings.Split(s[0], "/"), nil + }) + // command execution function + // first argument is a slice of arguments from setter.Argument + // all subsequent arguments must be in the same order and types as listed in setter.Flag + setter.ExecFunc(func(args []string, a string, b int64, c float64, d bool) { + fmt.Println(args, a, b, c, d) + }) +}), +``` + +### example of execution results + +**go run main.go --help** +```text +Usage: + tool [command] [args] + +Available Commands: + simple first-level command + +_____________________________________________________ +Use flag --help for more information about a command. + +``` +**go run main.go simple --help** +```text +Usage: + tool simple [arg] -a=demo -b=1 --cc=1e-05 -d + +Flags: + -a this is a string argument (default: demo) + -b this is a int64 argument (default: 1) + --cc this is a float64 argument (default: 1e-05) + -d this is a bool argument (default: true) + + +Examples: + tool simple aa/bb/cc -a=hello -b=123 --cc=123.456 -e +``` + +## Creating multi-level command tree + +To create a multi-level command tree, +you need to add the child command to the parent via the `AddCommand` method. + +At the same time, in the parent command, it is enough to +specify only the name and description via the `Setup` method. + +```go +root := console.New("tool", "help tool") + +simpleCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("simple", "third level") + .... +}) + +twoCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("two", "second level") + setter.AddCommand(simpleCmd) +}) + +oneCmd := console.NewCommand(func(setter console.CommandSetter) { + setter.Setup("one", "first level") + setter.AddCommand(twoCmd) +}) + +root.AddCommand(oneCmd) +root.Exec() +``` + +### example of execution results + +**go run main.go --help** + +```text +Usage: + tool [command] [args] + +Available Commands: + one first level + +_____________________________________________________ +Use flag --help for more information about a command. +``` +**go run main.go one --help** + +```text +Usage: + tool one [command] [args] + +Available Commands: + two second level + +_____________________________________________________ +Use flag --help for more information about a command. +``` +**go run main.go one two --help** +```text +Usage: + tool one two [command] [args] + +Available Commands: + simple third level + +_____________________________________________________ +Use flag --help for more information about a command. +``` +**go run main.go one two simple --help** +```text +Usage: + tool one two simple [arg] -a=demo -b=1 --cc=1e-05 -d + +Flags: + -a this is a string argument (default: demo) + -b this is a int64 argument (default: 1) + --cc this is a float64 argument (default: 1e-05) + -d this is a bool argument (default: false) + + +Examples: + tool simple aa/bb/cc -a=hello -b=123 --cc=123.456 -e + +``` \ No newline at end of file diff --git a/sdk/console/args.go b/sdk/console/args.go new file mode 100644 index 0000000..8c0858b --- /dev/null +++ b/sdk/console/args.go @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import "strings" + +type ( + //ValidFunc validate argument interface + ValidFunc func([]string) ([]string, error) + //Argument model + Argument struct { + ValidFunc ValidFunc + } +) + +// NewArgument constructor +func NewArgument() *Argument { + return &Argument{} +} + +type ( + //Args list model + Args struct { + list []Arg + next []string + } + //Arg model + Arg struct { + Key string + Value string + } + //ArgGetter argument getter interface + ArgGetter interface { + Has(name string) bool + Get(name string) *string + } +) + +// NewArgs constructor +func NewArgs() *Args { + return &Args{ + list: make([]Arg, 0), + next: make([]string, 0), + } +} + +func (a *Args) Has(name string) bool { + for _, v := range a.list { + if v.Key == name { + return true + } + } + return false +} + +func (a *Args) Get(name string) *string { + for _, v := range a.list { + if v.Key == name { + return &v.Value + } + } + return nil +} + +func (a *Args) Next() []string { + return a.next +} + +func (a *Args) Parse(list []string) *Args { + for i := 0; i < len(list); i++ { + // args + if strings.HasPrefix(list[i], "-") { + arg := Arg{} + v := strings.TrimLeft(list[i], "-") + vs := strings.SplitN(v, "=", 2) + switch len(vs) { + case 1: + arg.Key, arg.Value = vs[0], "" + a.list = append(a.list, arg) + continue + case 2: + arg.Key, arg.Value = vs[0], vs[1] + a.list = append(a.list, arg) + continue + } + + if i+1 < len(list) && !strings.HasPrefix(list[i+1], "-") { + arg.Key, arg.Value = vs[0], list[i+1] + a.list = append(a.list, arg) + i++ + continue + } + + arg.Key = vs[0] + a.list = append(a.list, arg) + continue + } + //commands + a.next = append(a.next, list[i]) + } + + return a +} diff --git a/sdk/console/command.go b/sdk/console/command.go new file mode 100644 index 0000000..198b68d --- /dev/null +++ b/sdk/console/command.go @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import ( + "fmt" + "reflect" +) + +type Command struct { + root bool + name string + desc string + examples []string + flags *Flags + args *Argument + execute interface{} + + next []CommandGetter +} + +type CommandGetter interface { + Next(string) CommandGetter + List() []CommandGetter + Validate() error + Is(string) bool + Name() string + Description() string + Examples() []string + ArgCall(d []string) ([]string, error) + Flags() FlagsGetter + Call() interface{} + AddCommand(...CommandGetter) + AsRoot() CommandGetter + IsRoot() bool +} + +type CommandSetter interface { + Setup(string, string) + Example(string) + Flag(cb func(FlagsSetter)) + ArgumentFunc(call ValidFunc) + ExecFunc(interface{}) + AddCommand(...CommandGetter) +} + +func NewCommand(cb func(CommandSetter)) CommandGetter { + cmd := &Command{ + next: make([]CommandGetter, 0), + flags: NewFlags(), + args: NewArgument(), + examples: make([]string, 0), + } + cb(cmd) + return cmd +} + +func (c *Command) Setup(name, description string) { + c.name, c.desc = name, description +} + +func (c *Command) AsRoot() CommandGetter { + c.root = true + c.name = "" + return c +} + +func (c *Command) IsRoot() bool { + return c.root +} + +func (c *Command) Name() string { + return c.name +} + +func (c *Command) Description() string { + return c.desc +} + +func (c *Command) Examples() []string { + return c.examples +} + +func (c *Command) Example(s string) { + c.examples = append(c.examples, s) +} + +func (c *Command) Flag(cb func(FlagsSetter)) { + cb(c.flags) +} + +func (c *Command) Flags() FlagsGetter { + return c.flags +} + +func (c *Command) ArgumentFunc(call ValidFunc) { + c.args.ValidFunc = call +} + +func (c *Command) ArgCall(d []string) ([]string, error) { + if c.args.ValidFunc == nil { + return d, nil + } + return c.args.ValidFunc(d) +} + +func (c *Command) ExecFunc(i interface{}) { + c.execute = i +} + +func (c *Command) Next(cmd string) CommandGetter { + for _, getter := range c.next { + if getter.Is(cmd) { + return getter + } + } + return nil +} + +func (c *Command) List() []CommandGetter { + return c.next +} + +func (c *Command) Validate() error { + if len(c.name) == 0 && !c.IsRoot() { + return fmt.Errorf("command name is empty. use Setup(name, description)") + } + if reflect.ValueOf(c.execute).Kind() != reflect.Func { + return fmt.Errorf("command [%s] ExecFunc: is not a func", c.name) + } + count := c.flags.Count() + 1 + if reflect.ValueOf(c.execute).Type().NumIn() != count { + return fmt.Errorf("command [%s] Flags: fewer arguments declared than expected in ExecFunc", c.name) + } + return nil +} + +func (c *Command) Call() interface{} { + return c.execute +} + +func (c *Command) Is(s string) bool { + return c.name == s +} + +func (c *Command) AddCommand(getter ...CommandGetter) { + for _, v := range getter { + if err := v.Validate(); err != nil { + Fatalf(err.Error()) + } + c.next = append(c.next, v) + } +} diff --git a/sdk/console/console.go b/sdk/console/console.go new file mode 100644 index 0000000..8d269ca --- /dev/null +++ b/sdk/console/console.go @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import ( + "os" + "reflect" +) + +const helpArg = "help" + +type Console struct { + name string + description string + root CommandGetter +} + +func New(name, description string) *Console { + return &Console{ + name: name, + description: description, + root: NewCommand(func(_ CommandSetter) {}).AsRoot(), + } +} + +func (c *Console) recover() { + if d := recover(); d != nil { + Fatalf("%+v", d) + } +} + +func (c *Console) AddCommand(getter ...CommandGetter) { + defer c.recover() + + c.root.AddCommand(getter...) +} + +func (c *Console) RootCommand(getter CommandGetter) { + defer c.recover() + + next := c.root.List() + c.root = getter.AsRoot() + if err := c.root.Validate(); err != nil { + Fatalf(err.Error()) + } + c.root.AddCommand(next...) +} + +func (c *Console) Exec() { + defer c.recover() + + args := NewArgs().Parse(os.Args[1:]) + cmd, cur, h := c.build(args) + if h { + help(c.name, c.description, cmd, cur) + return + } + c.run(cmd, args.Next()[len(cur):], args) +} + +func (c *Console) build(args *Args) (CommandGetter, []string, bool) { + var ( + i int + cmd string + + command CommandGetter + cur []string + help bool + ) + for i, cmd = range args.Next() { + if i == 0 { + if nc := c.root.Next(cmd); nc != nil { + command = nc + continue + } + command = c.root + break + } else { + if nc := command.Next(cmd); nc != nil { + command = nc + continue + } + break + } + } + + if len(args.Next()) > 0 { + cur = args.Next()[:i] + } else { + command = c.root + } + + if args.Has(helpArg) { + help = true + } + + return command, cur, help +} + +func (c *Console) run(command CommandGetter, a []string, args *Args) { + rv := make([]reflect.Value, 0) + + if command == nil || command.Call() == nil { + Fatalf("command not found") + } + + val, err := command.ArgCall(a) + if err != nil { + Fatalf("command \"%s\" validate arguments: %s", command.Name(), err.Error()) + } + rv = append(rv, reflect.ValueOf(val)) + + err = command.Flags().Call(args, func(i interface{}) { + rv = append(rv, reflect.ValueOf(i)) + }) + if err != nil { + Fatalf("command \"%s\" validate flags: %s", command.Name(), err.Error()) + } + + if reflect.ValueOf(command.Call()).Type().NumIn() != len(rv) { + Fatalf("command \"%s\" Flags: fewer arguments declared than expected in ExecFunc", command.Name()) + } + + reflect.ValueOf(command.Call()).Call(rv) +} diff --git a/sdk/console/flags.go b/sdk/console/flags.go new file mode 100644 index 0000000..125a58b --- /dev/null +++ b/sdk/console/flags.go @@ -0,0 +1,184 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import ( + "fmt" + "strconv" +) + +type ( + //Flags model + Flags struct { + d []FlagItem + } + //FlagItem element of flag model + FlagItem struct { + req bool + name string + value interface{} + usage string + call func(getter ArgGetter) (interface{}, error) + } +) + +// FlagsGetter getter interface +type FlagsGetter interface { + Info(cb func(bool, string, interface{}, string)) + Call(g ArgGetter, cb func(interface{})) error +} + +// FlagsSetter setter interface +type FlagsSetter interface { + StringVar(name string, value string, usage string) + String(name string, usage string) + IntVar(name string, value int64, usage string) + Int(name string, usage string) + FloatVar(name string, value float64, usage string) + Float(name string, usage string) + Bool(name string, usage string) +} + +// NewFlags init new flag +func NewFlags() *Flags { + return &Flags{ + d: make([]FlagItem, 0), + } +} + +// Count of flags +func (f *Flags) Count() int { + return len(f.d) +} + +// Info about command +func (f *Flags) Info(cb func(req bool, name string, v interface{}, usage string)) { + for _, item := range f.d { + cb(item.req, item.name, item.value, item.usage) + } +} + +func (f *Flags) Call(g ArgGetter, cb func(interface{})) error { + for _, item := range f.d { + v, err := item.call(g) + if err != nil { + return err + } + cb(v) + } + return nil +} + +// StringVar flag decoder with default value +func (f *Flags) StringVar(name string, value string, usage string) { + f.d = append(f.d, FlagItem{ + req: false, + name: name, + value: value, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil { + return *val, nil + } + return value, nil + }, + }) +} + +// String flag decoder +func (f *Flags) String(name string, usage string) { + f.d = append(f.d, FlagItem{ + req: true, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil && len(*val) > 0 { + return *val, nil + } + return nil, fmt.Errorf("--%s is not found", name) + }, + }) +} + +// IntVar flag decoder with default value +func (f *Flags) IntVar(name string, value int64, usage string) { + f.d = append(f.d, FlagItem{ + req: false, + value: value, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil && len(*val) > 0 { + return strconv.ParseInt(*val, 10, 64) + } + return value, nil + }, + }) +} + +// Int flag decoder +func (f *Flags) Int(name string, usage string) { + f.d = append(f.d, FlagItem{ + req: true, + value: 0, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil && len(*val) > 0 { + return strconv.ParseInt(*val, 10, 64) + } + return nil, fmt.Errorf("--%s is not found", name) + }, + }) +} + +// FloatVar flag decoder with default value +func (f *Flags) FloatVar(name string, value float64, usage string) { + f.d = append(f.d, FlagItem{ + req: false, + value: value, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil && len(*val) > 0 { + return strconv.ParseFloat(*val, 64) + } + return value, nil + }, + }) +} + +// Float flag decoder +func (f *Flags) Float(name string, usage string) { + f.d = append(f.d, FlagItem{ + req: true, + value: 0.0, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if val := getter.Get(name); val != nil && len(*val) > 0 { + return strconv.ParseFloat(*val, 64) + } + return nil, fmt.Errorf("--%s is not found", name) + }, + }) +} + +// Bool flag decoder +func (f *Flags) Bool(name string, usage string) { + f.d = append(f.d, FlagItem{ + req: false, + value: false, + name: name, + usage: usage, + call: func(getter ArgGetter) (interface{}, error) { + if getter.Has(name) { + return true, nil + } + return false, nil + }, + }) +} diff --git a/sdk/console/help.go b/sdk/console/help.go new file mode 100644 index 0000000..a08e05b --- /dev/null +++ b/sdk/console/help.go @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import ( + "fmt" + "os" + "sort" + "strings" + "text/template" +) + +var helpTemplate = `{{if len .Description | ne 0}}{{.Description}}{{end}} +{{if .ShowCommand}} +Current Command: + {{.Name}} {{.Curr}} {{.Args}} {{range $ex := .FlagsEx}} {{$ex}}{{end}} + +Flags: +{{range $ex := .Flags}} {{$ex}} +{{end}} +Examples: +{{range $ex := .Examples}} {{$ex}} +{{end}} +_____________________________________________________{{end}} +{{if len .Next | ne 0}} +Usage: + {{.Name}} {{.Curr}} [command] [args] + +Available Commands: +{{range $ex := .Next}} {{$ex}} +{{end}}{{end}} +_____________________________________________________ +Use flag --help for more information about a command. + +` + +type helpModel struct { + Name string + Description string + ShowCommand bool + + Args string + Examples []string + FlagsEx []string + Flags []string + + Curr string + Next []string +} + +func help(tool string, desc string, c CommandGetter, args []string) { + model := &helpModel{ + ShowCommand: c != nil && c.Call() != nil, + Name: tool, + Description: desc, + + Curr: strings.Join(args, " "), + Next: func() (out []string) { + if c == nil { + return + } + var max int + next := c.List() + for _, v := range next { + if max < len(v.Name()) { + max = len(v.Name()) + } + } + sort.Slice(next, func(i, j int) bool { + return next[i].Name() < next[j].Name() + }) + for _, v := range next { + out = append(out, v.Name()+strings.Repeat(" ", max-len(v.Name()))+" "+v.Description()) + } + + return + }(), + } + + if c != nil { + model.Examples = func() (out []string) { + for _, v := range c.Examples() { + out = append(out, tool+" "+v) + } + return + }() + model.Args = "[arg]" + model.Flags = func() (out []string) { + max := 0 + c.Flags().Info(func(r bool, n string, v interface{}, u string) { + if len(n) > max { + max = len(n) + } + }) + c.Flags().Info(func(r bool, n string, v interface{}, u string) { + ex, i := "", 1 + if !r { + ex = fmt.Sprintf("(default: %+v)", v) + } + if len(n) > 1 { + i = 2 + } + out = append(out, fmt.Sprintf( + "%s%s%s %s %s", + strings.Repeat("-", i), n, strings.Repeat(" ", max-len(n)), u, ex)) + }) + return + }() + model.FlagsEx = func() (out []string) { + c.Flags().Info(func(r bool, n string, v interface{}, u string) { + i, ex := 1, "" + if len(n) > 1 { + i = 2 + } + switch v.(type) { + case bool: + default: + ex = fmt.Sprintf("=%+v", v) + } + out = append(out, fmt.Sprintf( + "%s%s%s", + strings.Repeat("-", i), n, ex)) + }) + return + }() + } + + if err := template.Must(template.New("").Parse(helpTemplate)).Execute(os.Stdout, model); err != nil { + Fatalf(err.Error()) + } +} diff --git a/sdk/console/io.go b/sdk/console/io.go new file mode 100644 index 0000000..c05c996 --- /dev/null +++ b/sdk/console/io.go @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package console + +import ( + "bufio" + "fmt" + "os" + "strings" + "sync/atomic" + + "github.com/osspkg/goppy/sdk/errors" +) + +const ( + cRESET = "\u001B[0m" + cBLACK = "\u001B[30m" + cRED = "\u001B[31m" + cGREEN = "\u001B[32m" + cYELLOW = "\u001B[33m" + cBLUE = "\u001B[34m" + cPURPLE = "\u001B[35m" + cCYAN = "\u001B[36m" + + eof = "\n" +) + +var ( + scan *bufio.Scanner + yesNo = []string{"y", "n"} + debugLevel uint32 = 0 +) + +func init() { + scan = bufio.NewScanner(os.Stdin) +} + +func output(msg string, vars []string, def string) { + if len(def) > 0 { + def = fmt.Sprintf(" [%s]", def) + } + v := "" + if len(vars) > 0 { + v = fmt.Sprintf(" (%s)", strings.Join(vars, "/")) + } + Infof("%s%s%s: ", msg, v, def) +} + +// Input console input request +func Input(msg string, vars []string, def string) string { + output(msg, vars, def) + + for { + if scan.Scan() { + r := scan.Text() + if len(r) == 0 { + return def + } + if len(vars) == 0 { + return r + } + for _, v := range vars { + if v == r { + return r + } + } + output("Bad answer! Try again", vars, def) + } + } +} + +// InputBool console bool input request +func InputBool(msg string, def bool) bool { + v := "n" + if def { + v = "y" + } + v = Input(msg, yesNo, v) + return v == "y" +} + +func color(c, msg string, args []interface{}) { + fmt.Printf(c+msg+cRESET, args...) +} + +func colorln(c, msg string, args []interface{}) { + if !strings.HasSuffix(msg, eof) { + msg += eof + } + color(c, msg, args) +} + +// Rawf console message writer without level info +func Rawf(msg string, args ...interface{}) { + colorln(cRESET, msg, args) +} + +// Infof console message writer for info level +func Infof(msg string, args ...interface{}) { + colorln(cRESET, "[INF] "+msg, args) +} + +// Warnf console message writer for warning level +func Warnf(msg string, args ...interface{}) { + colorln(cYELLOW, "[WAR] "+msg, args) +} + +// Errorf console message writer for error level +func Errorf(msg string, args ...interface{}) { + colorln(cRED, "[ERR] "+msg, args) +} + +// ShowDebug init show debug +func ShowDebug(ok bool) { + var v uint32 = 0 + if ok { + v = 1 + } + atomic.StoreUint32(&debugLevel, v) +} + +// Debugf console message writer for debug level +func Debugf(msg string, args ...interface{}) { + if atomic.LoadUint32(&debugLevel) > 0 { + colorln(cBLUE, "[DEB] "+msg, args) + } +} + +// FatalIfErr console message writer if err is not nil +func FatalIfErr(err error, msg string, args ...interface{}) { + if err != nil { + Fatalf(errors.Wrapf(err, msg, args...).Error()) + } +} + +// Fatalf console message writer with exit code 1 +func Fatalf(msg string, args ...interface{}) { + colorln(cRED, "[ERR] "+msg, args) + os.Exit(1) +} diff --git a/sdk/context/contexts.go b/sdk/context/contexts.go new file mode 100644 index 0000000..4f8d689 --- /dev/null +++ b/sdk/context/contexts.go @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package context + +import ( + cc "context" + "reflect" +) + +func Combine(multi ...cc.Context) (cc.Context, cc.CancelFunc) { + ctx, cancel := cc.WithCancel(cc.Background()) + + go func() { + cases := make([]reflect.SelectCase, 0, len(multi)) + for _, vv := range multi { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(vv.Done()), + }) + } + chosen, _, _ := reflect.Select(cases) + switch chosen { + default: + cancel() + } + }() + + return ctx, cancel +} diff --git a/sdk/context/contexts_test.go b/sdk/context/contexts_test.go new file mode 100644 index 0000000..6b14d14 --- /dev/null +++ b/sdk/context/contexts_test.go @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package context_test + +import ( + ccc "context" + "fmt" + "testing" + "time" + + "github.com/osspkg/goppy/sdk/context" +) + +func TestUnit_Combine(t *testing.T) { + c, cancel := context.Combine(ccc.Background(), ccc.Background()) + if c == nil { + t.Fatalf("contexts.Combine returned nil") + } + + select { + case <-c.Done(): + t.Fatalf("<-c.Done() == it should block") + default: + } + + cancel() + <-time.After(time.Second) + + select { + case <-c.Done(): + default: + t.Fatalf("<-c.Done() it shouldn't block") + } + + if got, want := fmt.Sprint(c), "context.Background.WithCancel"; got != want { + t.Fatalf("contexts.Combine() = %q want %q", got, want) + } +} diff --git a/sdk/domain/domain.go b/sdk/domain/domain.go new file mode 100644 index 0000000..ea6960e --- /dev/null +++ b/sdk/domain/domain.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package domain + +var dot = byte('.') + +func Level(s string, level int) string { + max := len(s) - 1 + count, pos := 0, 0 + if s[max] == dot { + max-- + } + + for i := max; i >= 0; i-- { + if s[i] == dot { + count++ + if count == level { + pos = i + 1 + break + } + } + } + return s[pos:] +} diff --git a/sdk/domain/domain_test.go b/sdk/domain/domain_test.go new file mode 100644 index 0000000..9218b9e --- /dev/null +++ b/sdk/domain/domain_test.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package domain_test + +import ( + "fmt" + "testing" + + "github.com/osspkg/goppy/sdk/domain" +) + +func TestUnit_Level(t *testing.T) { + type args struct { + s string + level int + } + tests := []struct { + args args + want string + }{ + { + args: args{ + s: "www.domain.ltd", + level: 1, + }, + want: "ltd", + }, + { + args: args{ + s: "www.domain.ltd", + level: 2, + }, + want: "domain.ltd", + }, + { + args: args{ + s: "www.domain.ltd", + level: 10, + }, + want: "www.domain.ltd", + }, + { + args: args{ + s: "www.domain.ltd.", + level: 1, + }, + want: "ltd.", + }, + } + for i, tt := range tests { + t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { + if got := domain.Level(tt.args.s, tt.args.level); got != tt.want { + t.Errorf("DomainLevel() = %v, want %v", got, tt.want) + } + }) + } +} + +func Benchmark_Level(b *testing.B) { + d := "www.domain.ltd." + e := "domain.ltd." + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if got := domain.Level(d, 2); got != e { + b.Errorf("Level() = %v, want %v", got, e) + } + } +} diff --git a/sdk/encryption/aesgcm/aesgcm.go b/sdk/encryption/aesgcm/aesgcm.go new file mode 100644 index 0000000..163b015 --- /dev/null +++ b/sdk/encryption/aesgcm/aesgcm.go @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package aesgcm + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "fmt" + "io" +) + +const keySize = 32 + +type Codec struct { + key []byte + block cipher.Block +} + +func New(key string) (*Codec, error) { + kb := []byte(key) + if len(kb) < keySize { + return nil, fmt.Errorf("invalid key len") + } + obj := &Codec{ + key: kb[:keySize], + } + block, err := aes.NewCipher(obj.key) + if err != nil { + return nil, err + } + obj.block = block + return obj, nil +} + +func (v *Codec) Encrypt(plaintext []byte) ([]byte, error) { + gcm, err := cipher.NewGCM(v.block) + if err != nil { + return nil, err + } + nonce := make([]byte, gcm.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + ciphertext := gcm.Seal(nonce, nonce, plaintext, nil) + return ciphertext, nil +} + +func (v *Codec) Decrypt(ciphertext []byte) ([]byte, error) { + gcm, err := cipher.NewGCM(v.block) + if err != nil { + return nil, err + } + nonceSize := gcm.NonceSize() + if len(ciphertext) < nonceSize { + return nil, fmt.Errorf("invalid message len") + } + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/sdk/encryption/aesgcm/aesgcm_test.go b/sdk/encryption/aesgcm/aesgcm_test.go new file mode 100644 index 0000000..aba1041 --- /dev/null +++ b/sdk/encryption/aesgcm/aesgcm_test.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package aesgcm_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/encryption/aesgcm" + "github.com/osspkg/goppy/sdk/random" + "github.com/stretchr/testify/require" +) + +func TestUnit_Codec(t *testing.T) { + rndKey := random.String(32) + message := []byte("Hello World!") + + c, err := aesgcm.New(rndKey) + require.NoError(t, err) + + enc1, err := c.Encrypt(message) + require.NoError(t, err) + + dec1, err := c.Decrypt(enc1) + require.NoError(t, err) + + require.Equal(t, message, dec1) + + c, err = aesgcm.New(rndKey) + require.NoError(t, err) + + enc2, err := c.Encrypt(message) + require.NoError(t, err) + + require.NotEqual(t, enc1, enc2) + + dec2, err := c.Decrypt(enc1) + require.NoError(t, err) + + require.Equal(t, message, dec2) + +} diff --git a/sdk/env/env.go b/sdk/env/env.go new file mode 100644 index 0000000..4eef6dc --- /dev/null +++ b/sdk/env/env.go @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package env + +import "os" + +func Get(key, def string) string { + v := os.Getenv(key) + if len(v) == 0 { + return def + } + return v +} diff --git a/sdk/errors/errors.go b/sdk/errors/errors.go new file mode 100644 index 0000000..1cca9ec --- /dev/null +++ b/sdk/errors/errors.go @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package errors + +import ( + e "errors" + "fmt" +) + +type err struct { + cause error + message string + trace string +} + +func New(message string) error { + return &err{message: message} +} + +func (v *err) Error() string { + switch true { + case len(v.message) > 0 && v.cause != nil: + return v.message + ": " + v.cause.Error() + v.trace + case v.cause != nil: + return v.cause.Error() + v.trace + } + return v.message + v.trace +} + +func (v *err) Cause() error { + return v.cause +} + +func (v *err) Unwrap() error { + return v.cause +} + +func (v *err) WithTrace() { + v.trace = tracing() +} + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +func Trace(cause error, message string, args ...interface{}) error { + v := Wrapf(cause, message, args...) + //nolint: errorlint + if vv, ok := v.(*err); ok { + vv.WithTrace() + return vv + } + return v +} + +func Wrapf(cause error, message string, args ...interface{}) error { + if cause == nil { + return nil + } + var err0 *err + if len(args) == 0 { + err0 = &err{ + cause: cause, + message: message, + } + } else { + err0 = &err{ + cause: cause, + message: fmt.Sprintf(message, args...), + } + } + return err0 +} + +func Wrap(msg ...error) error { + if len(msg) == 0 { + return nil + } + var err0 error + for _, v := range msg { + if v == nil { + continue + } + if err0 == nil { + err0 = &err{cause: v} + continue + } + err0 = &err{ + cause: v, + message: err0.Error(), + } + } + return err0 +} + +func Unwrap(err error) error { + //nolint: errorlint + if v, ok := err.(interface { + Unwrap() error + }); ok { + return v.Unwrap() + } + return nil +} + +func Cause(err error) error { + for err != nil { + //nolint: errorlint + v, ok := err.(interface { + Cause() error + }) + if !ok { + return err + } + err = v.Cause() + } + + return nil +} + +func Is(err, target error) bool { + return e.Is(err, target) +} diff --git a/sdk/errors/errors_test.go b/sdk/errors/errors_test.go new file mode 100644 index 0000000..5e9822a --- /dev/null +++ b/sdk/errors/errors_test.go @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package errors_test + +import ( + e "errors" + "strings" + "testing" + + "github.com/osspkg/goppy/sdk/errors" +) + +func TestUnit_New(t *testing.T) { + type args struct { + message string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {name: "Case1", args: args{message: "hello"}, want: "hello", wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.New(tt.args.message) + if (err != nil) != tt.wantErr { + t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) + return + } + if err.Error() != tt.want { + t.Errorf("New() error = %v, want %v", err.Error(), tt.want) + return + } + }) + } +} + +func TestUnit_Wrap(t *testing.T) { + type args struct { + msg []error + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "Case1", + args: args{msg: nil}, + want: "", + wantErr: false, + }, + { + name: "Case2", + args: args{msg: []error{errors.New("hello"), e.New("world")}}, + want: "hello: world", + wantErr: true, + }, + { + name: "Case3", + args: args{msg: []error{errors.New("err1"), e.New("err2"), nil, e.New("err3")}}, + want: "err1: err2: err3", + wantErr: true, + }, + { + name: "Case4", + args: args{msg: []error{errors.Wrapf(errors.New("err1"), "err1 message"), + errors.Wrapf(e.New("err2"), "err2 message"), + errors.Wrapf(e.New("err3"), "err3 message")}}, + want: "err1 message: err1: err2 message: err2: err3 message: err3", + wantErr: true, + }, + { + name: "Case5", + args: args{msg: []error{nil, nil, nil}}, + want: "", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.Wrap(tt.args.msg...) + if (err != nil) != tt.wantErr { + t.Errorf("Wrap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err.Error() != tt.want { + t.Errorf("Wrap() error = %v, want %v", err.Error(), tt.want) + return + } + }) + } +} + +func TestUnit_WrapMessage(t *testing.T) { + type args struct { + cause error + message string + args []interface{} + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + { + name: "Case1", + args: args{ + cause: nil, + message: "err context", + args: nil, + }, + want: "", + wantErr: false, + }, + { + name: "Case2", + args: args{ + cause: e.New("err1"), + message: "err context", + args: nil, + }, + want: "err context: err1", + wantErr: true, + }, + { + name: "Case3", + args: args{ + cause: e.New("err1"), + message: "bad ip %s", + args: []interface{}{"127.0.0.1"}, + }, + want: "bad ip 127.0.0.1: err1", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := errors.Wrapf(tt.args.cause, tt.args.message, tt.args.args...) + if (err != nil) != tt.wantErr { + t.Errorf("Wrapf() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err.Error() != tt.want { + t.Errorf("Wrapf() error = %v, want %v", err.Error(), tt.want) + return + } + }) + } +} + +func TestUnit_CauseUnwrap(t *testing.T) { + type fields struct { + cause error + message string + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + { + name: "Case1", + fields: fields{ + cause: e.New("err1"), + message: "context", + }, + want: "err1", + wantErr: true, + }, + { + name: "Case2", + fields: fields{ + cause: nil, + message: "context", + }, + want: "err1", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := errors.Wrapf(tt.fields.cause, tt.fields.message) + err := errors.Cause(v) + if (err != nil) != tt.wantErr { + t.Errorf("Cause() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err.Error() != tt.want { + t.Errorf("Cause() error = %v, want %v", err.Error(), tt.want) + return + } + err = errors.Unwrap(v) + if (err != nil) != tt.wantErr { + t.Errorf("Unwrap() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && err.Error() != tt.want { + t.Errorf("Unwrap() error = %v, want %v", err.Error(), tt.want) + return + } + }) + } +} + +func TestUnit_Is(t *testing.T) { + err0 := errors.New("test") + type args struct { + err error + target error + } + tests := []struct { + name string + args args + want bool + }{ + {name: "Case1", args: args{err: err0, target: err0}, want: true}, + {name: "Case2", args: args{err: errors.Wrapf(err0, "ttt"), target: err0}, want: true}, + {name: "Case3", args: args{err: errors.New("hello"), target: err0}, want: false}, + {name: "Case4", args: args{err: nil, target: err0}, want: false}, + {name: "Case5", args: args{err: errors.New("hello"), target: nil}, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := errors.Is(tt.args.err, tt.args.target); got != tt.want { + t.Errorf("Is() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnit_Trace(t *testing.T) { + tests := []struct { + name string + err error + want string + }{ + { + name: "Case1", + err: errors.New("test"), + want: "[trace] github.com/osspkg/goppy/sdk/errors_test.TestUnit_Trace.func1", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := errors.Trace(tt.err, "msg"); got != nil && !strings.Contains(got.Error(), tt.want) { + t.Errorf("Trace() = %v, want %v", got.Error(), tt.want) + } + }) + } +} diff --git a/sdk/errors/trace.go b/sdk/errors/trace.go new file mode 100644 index 0000000..8e537fa --- /dev/null +++ b/sdk/errors/trace.go @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package errors + +import ( + "fmt" + "runtime" +) + +func tracing() string { + var list [10]uintptr + + n := runtime.Callers(4, list[:]) + frame := runtime.CallersFrames(list[:n]) + + result := "" + for { + v, ok := frame.Next() + if !ok { + break + } + result += fmt.Sprintf("\n\t[trace] %s:%d", v.Function, v.Line) + } + return result +} diff --git a/sdk/iofile/encdec.go b/sdk/iofile/encdec.go new file mode 100644 index 0000000..77218ec --- /dev/null +++ b/sdk/iofile/encdec.go @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iofile + +import ( + "encoding/json" + "os" + "path/filepath" + "sync" + + "github.com/osspkg/goppy/sdk/errors" + "gopkg.in/yaml.v3" +) + +var ( + errBadFileFormat = errors.New("format is not a supported") + + fileCodec = newCodec(). + Add(".yml", yaml.Marshal, yaml.Unmarshal). + Add(".yaml", yaml.Marshal, yaml.Unmarshal). + Add(".json", json.Marshal, json.Unmarshal) +) + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type codec struct { + enc map[string]func(v interface{}) ([]byte, error) + dec map[string]func([]byte, interface{}) error + mux sync.RWMutex +} + +func newCodec() *codec { + return &codec{ + enc: make(map[string]func(v interface{}) ([]byte, error), 10), + dec: make(map[string]func([]byte, interface{}) error, 10), + mux: sync.RWMutex{}, + } +} + +func AddFileCodec(ext string, enc func(v interface{}) ([]byte, error), dec func([]byte, interface{}) error) { + fileCodec.Add(ext, enc, dec) +} + +func (v *codec) Add(ext string, enc func(v interface{}) ([]byte, error), dec func([]byte, interface{}) error) *codec { + v.mux.Lock() + defer v.mux.Unlock() + v.enc[ext] = enc + v.dec[ext] = dec + return v +} + +func (v *codec) GetEnc(ext string) (func(v interface{}) ([]byte, error), bool) { + v.mux.RLock() + defer v.mux.RUnlock() + fn, ok := v.enc[ext] + return fn, ok +} + +func (v *codec) GetDec(ext string) (func([]byte, interface{}) error, bool) { + v.mux.RLock() + defer v.mux.RUnlock() + fn, ok := v.dec[ext] + return fn, ok +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type FileCodec string + +func (v FileCodec) Decode(configs ...interface{}) error { + data, err := os.ReadFile(string(v)) + if err != nil { + return err + } + ext := filepath.Ext(string(v)) + c, ok := fileCodec.GetDec(ext) + if !ok { + return errBadFileFormat + } + return v.dec(data, c, configs...) +} + +func (v FileCodec) Encode(configs ...interface{}) error { + ext := filepath.Ext(string(v)) + c, ok := fileCodec.GetEnc(ext) + if !ok { + return errBadFileFormat + } + b, err := v.enc(c, configs...) + if err != nil { + return err + } + return os.WriteFile(string(v), b, 0755) +} + +func (v FileCodec) dec(data []byte, call func([]byte, interface{}) error, configs ...interface{}) error { + for _, conf := range configs { + if err := call(data, conf); err != nil { + return err + } + } + return nil +} + +func (v FileCodec) enc(call func(v interface{}) ([]byte, error), configs ...interface{}) ([]byte, error) { + b := make([]byte, 0, 300*len(configs)) + for _, conf := range configs { + bb, err := call(conf) + if err != nil { + return nil, err + } + b = append(b, '\n', '\n') + b = append(b, bb...) + } + return b, nil +} diff --git a/sdk/iofile/encdec_test.go b/sdk/iofile/encdec_test.go new file mode 100644 index 0000000..6f3b119 --- /dev/null +++ b/sdk/iofile/encdec_test.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iofile_test + +import ( + "os" + "testing" + + "github.com/osspkg/goppy/sdk/iofile" + "github.com/stretchr/testify/require" +) + +func TestFile_EncodeDecode(t *testing.T) { + type TestDataItem1 struct { + AA string `yaml:"aa"` + BB bool `yaml:"bb"` + } + type TestData1 struct { + Data1 TestDataItem1 `yaml:"data-1"` + } + type TestDataItem2 struct { + CC string `yaml:"cc"` + DD int `yaml:"dd"` + } + type TestData2 struct { + Data2 TestDataItem2 `yaml:"data-2"` + } + + os.Remove("/tmp/bdsbdnsabkjlfadlksjfbkljd.yaml") + + model1 := &TestData1{Data1: TestDataItem1{AA: "123", BB: true}} + model2 := &TestData2{Data2: TestDataItem2{CC: "qwer", DD: -100}} + + err := iofile.FileCodec("/tmp/bdsbdnsabkjlfadlksjfbkljd.yaml").Encode(model1, model2) + require.NoError(t, err) + + model11 := &TestData1{} + model22 := &TestData2{} + + err = iofile.FileCodec("/tmp/bdsbdnsabkjlfadlksjfbkljd.yaml").Decode(model11, model22) + require.NoError(t, err) + + require.Equal(t, model1, model11) + require.Equal(t, model2, model22) +} diff --git a/sdk/iofile/files.go b/sdk/iofile/files.go new file mode 100644 index 0000000..d2e9ef0 --- /dev/null +++ b/sdk/iofile/files.go @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iofile + +import ( + "io" + "io/fs" + "os" + "path/filepath" + "strings" +) + +func Exist(filename string) bool { + _, err := os.Stat(filename) + return err == nil +} + +func Search(dir, filename string) ([]string, error) { + files := make([]string, 0) + err := filepath.Walk(dir, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() || info.Name() != filename { + return nil + } + files = append(files, path) + return nil + }) + return files, err +} + +func Rewrite(filename string, call func([]byte) ([]byte, error)) error { + if !Exist(filename) { + if err := os.WriteFile(filename, []byte(""), 0755); err != nil { + return err + } + } + b, err := os.ReadFile(filename) + if err != nil { + return err + } + if b, err = call(b); err != nil { + return err + } + return os.WriteFile(filename, b, 0755) +} + +func Copy(dst, src string, mode os.FileMode) error { + source, err := os.OpenFile(src, os.O_RDONLY, 0) + if err != nil { + return err + } + defer source.Close() //nolint: errcheck + + if mode == 0 { + fi, err0 := source.Stat() + if err0 != nil { + return err0 + } + mode = fi.Mode() + } + + dist, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE|os.O_TRUNC, mode) + if err != nil { + return err + } + defer dist.Close() //nolint: errcheck + + _, err = io.Copy(dist, source) + return err +} + +func Folder(filename string) string { + dir := filepath.Dir(filename) + tree := strings.Split(dir, string(os.PathSeparator)) + return tree[len(tree)-1] +} diff --git a/sdk/iofile/hash.go b/sdk/iofile/hash.go new file mode 100644 index 0000000..7edd892 --- /dev/null +++ b/sdk/iofile/hash.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iofile + +import ( + "encoding/hex" + "fmt" + "hash" + "io" + "os" + + "github.com/osspkg/goppy/sdk/errors" +) + +func IsValidHash(filename string, h hash.Hash, valid string) error { + r, err := os.Open(filename) + if err != nil { + return err + } + if _, err = io.Copy(h, r); err != nil { + return errors.Wrapf(err, "calculate file hash") + } + result := hex.EncodeToString(h.Sum(nil)) + h.Reset() + if result != valid { + return fmt.Errorf("invalid hash: expected[%s] actual[%s]", valid, result) + } + return nil +} + +func Hash(filename string, h hash.Hash) (string, error) { + r, err := os.Open(filename) + if err != nil { + return "", err + } + if _, err = io.Copy(h, r); err != nil { + return "", errors.Wrapf(err, "calculate file hash") + } + result := hex.EncodeToString(h.Sum(nil)) + h.Reset() + return result, nil +} diff --git a/sdk/iosync/group.go b/sdk/iosync/group.go new file mode 100644 index 0000000..dc0d868 --- /dev/null +++ b/sdk/iosync/group.go @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iosync + +import "sync" + +type ( + Group interface { + Wait() + Background(call func()) + Run(call func()) + } + + _group struct { + wg sync.WaitGroup + sync Switch + } +) + +func NewGroup() Group { + return &_group{ + sync: NewSwitch(), + } +} + +func (v *_group) Wait() { + v.sync.On() + v.wg.Wait() + v.sync.Off() +} + +func (v *_group) Background(call func()) { + if v.sync.IsOn() { + return + } + v.wg.Add(1) + go func() { + call() + v.wg.Done() + }() +} + +func (v *_group) Run(call func()) { + if v.sync.IsOn() { + return + } + v.wg.Add(1) + call() + v.wg.Done() +} diff --git a/sdk/iosync/locker.go b/sdk/iosync/locker.go new file mode 100644 index 0000000..4f179ba --- /dev/null +++ b/sdk/iosync/locker.go @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iosync + +import "sync" + +type ( + Lock interface { + RLock(call func()) + Lock(call func()) + } + _lock struct { + mux sync.RWMutex + } +) + +func NewLock() Lock { + return &_lock{} +} + +func (v *_lock) Lock(call func()) { + v.mux.Lock() + call() + v.mux.Unlock() +} +func (v *_lock) RLock(call func()) { + v.mux.RLock() + call() + v.mux.RUnlock() +} diff --git a/sdk/iosync/switcher.go b/sdk/iosync/switcher.go new file mode 100644 index 0000000..4432ed4 --- /dev/null +++ b/sdk/iosync/switcher.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iosync + +import "sync/atomic" + +const ( + on uint64 = 1 + off uint64 = 0 +) + +type ( + Switch interface { + On() bool + Off() bool + IsOn() bool + IsOff() bool + } + + _switch struct { + i uint64 + } +) + +func NewSwitch() Switch { + return &_switch{i: 0} +} + +func (v *_switch) On() bool { + return atomic.CompareAndSwapUint64(&v.i, off, on) +} + +func (v *_switch) Off() bool { + return atomic.CompareAndSwapUint64(&v.i, on, off) +} + +func (v *_switch) IsOn() bool { + return atomic.LoadUint64(&v.i) == on +} + +func (v *_switch) IsOff() bool { + return atomic.LoadUint64(&v.i) == off +} diff --git a/sdk/iosync/switcher_test.go b/sdk/iosync/switcher_test.go new file mode 100644 index 0000000..a02ff1e --- /dev/null +++ b/sdk/iosync/switcher_test.go @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package iosync_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/iosync" + "github.com/stretchr/testify/require" +) + +func TestNewSwitch(t *testing.T) { + sync := iosync.NewSwitch() + + require.False(t, sync.IsOn()) + require.True(t, sync.IsOff()) + + require.True(t, sync.On()) + require.False(t, sync.On()) + + require.False(t, sync.IsOff()) + require.True(t, sync.IsOn()) + +} diff --git a/sdk/ioutil/ioutils.go b/sdk/ioutil/ioutils.go new file mode 100644 index 0000000..f129fb6 --- /dev/null +++ b/sdk/ioutil/ioutils.go @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package ioutil + +import ( + "bytes" + "io" + + "github.com/osspkg/goppy/sdk/errors" +) + +func ReadAll(r io.ReadCloser) ([]byte, error) { + b, err := io.ReadAll(r) + err = errors.Wrap(err, r.Close()) + if err != nil { + return nil, err + } + return b, nil +} + +func ReadBytes(v io.Reader, divide string) ([]byte, error) { + var ( + n int + err error + b = make([]byte, 0, 512) + db = []byte(divide) + dl = len(db) + ) + + for { + if len(b) == cap(b) { + b = append(b, 0)[:len(b)] + } + n, err = v.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if err == io.EOF { + break + } + return nil, err + } + if len(b) < dl { + return b, io.EOF + } + if bytes.Equal(db, b[len(b)-dl:]) { + b = b[:len(b)-dl] + break + } + } + return b, nil +} + +func WriteBytes(v io.Writer, b []byte, divide string) error { + var ( + db = []byte(divide) + dl = len(db) + ) + if len(b) < dl || !bytes.Equal(db, b[len(b)-dl:]) { + b = append(b, db...) + } + if _, err := v.Write(b); err != nil { + return err + } + return nil +} diff --git a/sdk/ioutil/ioutils_test.go b/sdk/ioutil/ioutils_test.go new file mode 100644 index 0000000..036e900 --- /dev/null +++ b/sdk/ioutil/ioutils_test.go @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package ioutil_test + +import ( + "bytes" + "io" + "reflect" + "testing" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/ioutil" +) + +type mockReadCloser struct { + Data *bytes.Buffer + ErrRead error + ErrClose error +} + +func (v *mockReadCloser) Read(p []byte) (int, error) { + if v.ErrRead != nil { + return 0, v.ErrRead + } + return v.Data.Read(p) +} + +func (v *mockReadCloser) Close() error { + return v.ErrClose +} + +func TestUnit_ReadAll(t *testing.T) { + type args struct { + r io.ReadCloser + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + { + name: "Case1", + args: args{ + r: &mockReadCloser{ + Data: bytes.NewBuffer([]byte(`hello`)), + ErrRead: nil, + ErrClose: nil, + }, + }, + want: []byte(`hello`), + wantErr: false, + }, + { + name: "Case2", + args: args{ + r: &mockReadCloser{ + Data: nil, + ErrRead: errors.New("read error"), + ErrClose: nil, + }, + }, + want: nil, + wantErr: true, + }, + { + name: "Case3", + args: args{ + r: &mockReadCloser{ + Data: nil, + ErrRead: errors.New("read error"), + ErrClose: errors.New("close error"), + }, + }, + want: nil, + wantErr: true, + }, + { + name: "Case4", + args: args{ + r: &mockReadCloser{ + Data: bytes.NewBuffer([]byte(`hello`)), + ErrRead: nil, + ErrClose: errors.New("close error"), + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ioutil.ReadAll(tt.args.r) + if (err != nil) != tt.wantErr { + t.Errorf("ReadAll() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ReadAll() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sdk/log/common.go b/sdk/log/common.go new file mode 100644 index 0000000..1de020a --- /dev/null +++ b/sdk/log/common.go @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log + +import "io" + +const ( + levelFatal uint32 = iota + LevelError + LevelWarn + LevelInfo + LevelDebug +) + +var levels = map[uint32]string{ + levelFatal: "FAT", + LevelError: "ERR", + LevelWarn: "WRN", + LevelInfo: "INF", + LevelDebug: "DBG", +} + +type Fields map[string]interface{} + +type Sender interface { + PutEntity(v *entity) + SendMessage(level uint32, call func(v *message)) + Close() +} + +// Writer interface +type Writer interface { + Fatalf(format string, args ...interface{}) + Errorf(format string, args ...interface{}) + Warnf(format string, args ...interface{}) + Infof(format string, args ...interface{}) + Debugf(format string, args ...interface{}) +} + +type WriterContext interface { + WithError(key string, err error) Writer + WithField(key string, value interface{}) Writer + WithFields(Fields) Writer + Writer +} + +// Logger base interface +type Logger interface { + SetOutput(out io.Writer) + SetLevel(v uint32) + GetLevel() uint32 + Close() + + WriterContext +} diff --git a/sdk/log/default.go b/sdk/log/default.go new file mode 100644 index 0000000..c799529 --- /dev/null +++ b/sdk/log/default.go @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log + +import "io" + +var std = New() + +// Default logger +func Default() Logger { + return std +} + +// SetOutput change writer +func SetOutput(out io.Writer) { + std.SetOutput(out) +} + +// SetLevel change log level +func SetLevel(v uint32) { + std.SetLevel(v) +} + +// GetLevel getting log level +func GetLevel() uint32 { + return std.GetLevel() +} + +// Close waiting for all messages to finish recording +func Close() { + std.Close() +} + +// Infof info message +func Infof(format string, args ...interface{}) { + std.Infof(format, args...) +} + +// Warnf warning message +func Warnf(format string, args ...interface{}) { + std.Warnf(format, args...) +} + +// Errorf error message +func Errorf(format string, args ...interface{}) { + std.Errorf(format, args...) +} + +// Debugf debug message +func Debugf(format string, args ...interface{}) { + std.Debugf(format, args...) +} + +// Fatalf fatal message and exit +func Fatalf(format string, args ...interface{}) { + std.Fatalf(format, args...) +} + +// WithFields setter context to log message +func WithFields(v Fields) Writer { + return std.WithFields(v) +} + +// WithError setter context to log message +func WithError(key string, err error) Writer { + return std.WithError(key, err) +} + +// WithField setter context to log message +func WithField(key string, value interface{}) Writer { + return std.WithField(key, value) +} diff --git a/sdk/log/entity.go b/sdk/log/entity.go new file mode 100644 index 0000000..d80e265 --- /dev/null +++ b/sdk/log/entity.go @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log + +import ( + "fmt" + "os" + "reflect" +) + +type entity struct { + log Sender + ctx Fields +} + +func newEntity(log Sender) *entity { + return &entity{ + log: log, + ctx: Fields{}, + } +} + +func (e *entity) Reset() { + e.ctx = Fields{} +} + +func (e *entity) WithError(key string, err error) Writer { + if err != nil { + e.ctx[key] = err.Error() + } else { + e.ctx[key] = nil + } + return e +} + +func (e *entity) WithField(key string, value interface{}) Writer { + ref := reflect.TypeOf(value) + if ref != nil { + switch ref.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Ptr, reflect.Struct: + e.ctx[key] = fmt.Sprintf("unsupported field value: %#v", value) + return e + } + } + e.ctx[key] = value + return e +} + +func (e *entity) WithFields(fields Fields) Writer { + for key, value := range fields { + ref := reflect.TypeOf(value) + if ref != nil { + switch ref.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Ptr, reflect.Struct: + e.ctx[key] = fmt.Sprintf("unsupported field value: %#v", value) + continue + } + } + e.ctx[key] = value + } + return e +} + +func (e *entity) prepareMessage(format string, args ...interface{}) func(v *message) { + return func(v *message) { + v.Message = fmt.Sprintf(format, args...) + for key, value := range e.ctx { + v.Ctx[key] = value + } + e.log.PutEntity(e) + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// Infof info message +func (e *entity) Infof(format string, args ...interface{}) { + e.log.SendMessage(LevelInfo, e.prepareMessage(format, args...)) +} + +// Warnf warning message +func (e *entity) Warnf(format string, args ...interface{}) { + e.log.SendMessage(LevelWarn, e.prepareMessage(format, args...)) +} + +// Errorf error message +func (e *entity) Errorf(format string, args ...interface{}) { + e.log.SendMessage(LevelError, e.prepareMessage(format, args...)) +} + +// Debugf debug message +func (e *entity) Debugf(format string, args ...interface{}) { + e.log.SendMessage(LevelDebug, e.prepareMessage(format, args...)) +} + +// Fatalf fatal message and exit +func (e *entity) Fatalf(format string, args ...interface{}) { + e.log.SendMessage(levelFatal, e.prepareMessage(format, args...)) + e.log.Close() + os.Exit(1) +} diff --git a/sdk/log/logger.go b/sdk/log/logger.go new file mode 100644 index 0000000..80249d2 --- /dev/null +++ b/sdk/log/logger.go @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log + +import ( + "encoding/json" + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/osspkg/goppy/sdk/iosync" +) + +var nl = byte('\n') + +// log base model +type log struct { + status uint32 + writer io.Writer + entities sync.Pool + channel chan []byte + wg iosync.Group +} + +// New init new logger +func New() Logger { + object := &log{ + status: LevelError, + writer: os.Stdout, + channel: make(chan []byte, 1024), + wg: iosync.NewGroup(), + } + object.entities = sync.Pool{ + New: func() interface{} { + return newEntity(object) + }, + } + object.wg.Background(func() { + object.queue() + }) + return object +} + +func (l *log) SendMessage(level uint32, call func(v *message)) { + if l.GetLevel() < level { + return + } + + m, ok := poolMessage.Get().(*message) + if !ok { + m = &message{} + } + + call(m) + lvl, ok := levels[level] + if !ok { + lvl = "UNK" + } + m.Level, m.Time = lvl, time.Now().Unix() + + b, err := json.Marshal(m) + if err != nil { + b = []byte(err.Error()) + } + + select { + case l.channel <- b: + default: + } + + m.Reset() + poolMessage.Put(m) +} + +func (l *log) queue() { + for { + m, ok := <-l.channel + if !ok { + return + } + if m == nil { + return + } + l.writer.Write(append(m, nl)) //nolint:errcheck + } +} + +func (l *log) getEntity() *entity { + lw, ok := l.entities.Get().(*entity) + if !ok { + lw = newEntity(l) + } + return lw +} + +func (l *log) PutEntity(v *entity) { + v.Reset() + l.entities.Put(v) +} + +// Close waiting for all messages to finish recording +func (l *log) Close() { + l.channel <- nil + l.wg.Wait() +} + +// SetOutput change writer +func (l *log) SetOutput(out io.Writer) { + l.writer = out +} + +// SetLevel change log level +func (l *log) SetLevel(v uint32) { + atomic.StoreUint32(&l.status, v) +} + +// GetLevel getting log level +func (l *log) GetLevel() uint32 { + return atomic.LoadUint32(&l.status) +} + +// Infof info message +func (l *log) Infof(format string, args ...interface{}) { + l.getEntity().Infof(format, args...) +} + +// Warnf warning message +func (l *log) Warnf(format string, args ...interface{}) { + l.getEntity().Warnf(format, args...) +} + +// Errorf error message +func (l *log) Errorf(format string, args ...interface{}) { + l.getEntity().Errorf(format, args...) +} + +// Debugf debug message +func (l *log) Debugf(format string, args ...interface{}) { + l.getEntity().Debugf(format, args...) +} + +// Fatalf fatal message and exit +func (l *log) Fatalf(format string, args ...interface{}) { + l.getEntity().Fatalf(format, args...) +} + +// WithFields setter context to log message +func (l *log) WithFields(v Fields) Writer { + return l.getEntity().WithFields(v) +} + +// WithError setter context to log message +func (l *log) WithError(key string, err error) Writer { + return l.getEntity().WithError(key, err) +} + +// WithField setter context to log message +func (l *log) WithField(key string, value interface{}) Writer { + return l.getEntity().WithField(key, value) +} diff --git a/sdk/log/logger_test.go b/sdk/log/logger_test.go new file mode 100644 index 0000000..6a00bb3 --- /dev/null +++ b/sdk/log/logger_test.go @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log_test + +import ( + "fmt" + "io" + "os" + "testing" + "time" + + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/stretchr/testify/require" +) + +func TestUnit_New(t *testing.T) { + require.NotNil(t, log.Default()) + + filename, err := os.CreateTemp(os.TempDir(), "test_new_default-*.log") + require.NoError(t, err) + + log.SetOutput(filename) + log.SetLevel(log.LevelDebug) + require.Equal(t, log.LevelDebug, log.GetLevel()) + + go log.Infof("async %d", 1) + go log.Warnf("async %d", 2) + go log.Errorf("async %d", 3) + go log.Debugf("async %d", 4) + + log.Infof("sync %d", 1) + log.Warnf("sync %d", 2) + log.Errorf("sync %d", 3) + log.Debugf("sync %d", 4) + + log.WithFields(log.Fields{"ip": "0.0.0.0"}).Infof("context1") + log.WithFields(log.Fields{"nil": nil}).Infof("context2") + log.WithFields(log.Fields{"func": func() {}}).Infof("context3") + + log.WithField("ip", "0.0.0.0").Infof("context4") + log.WithField("nil", nil).Infof("context5") + log.WithField("func", func() {}).Infof("context6") + + log.WithError("err", nil).Infof("context7") + log.WithError("err", fmt.Errorf("er1")).Infof("context8") + + <-time.After(time.Second * 1) + log.Close() + + require.NoError(t, filename.Close()) + data, err := os.ReadFile(filename.Name()) + require.NoError(t, err) + require.NoError(t, os.Remove(filename.Name())) + + sdata := string(data) + require.Contains(t, sdata, `"lvl":"INF","msg":"async 1"`) + require.Contains(t, sdata, `"lvl":"WRN","msg":"async 2"`) + require.Contains(t, sdata, `"lvl":"ERR","msg":"async 3"`) + require.Contains(t, sdata, `"lvl":"DBG","msg":"async 4"`) + require.Contains(t, sdata, `"lvl":"INF","msg":"sync 1"`) + require.Contains(t, sdata, `"lvl":"WRN","msg":"sync 2"`) + require.Contains(t, sdata, `"lvl":"ERR","msg":"sync 3"`) + require.Contains(t, sdata, `"msg":"context1","ctx":{"ip":"0.0.0.0"}`) + require.Contains(t, sdata, `"msg":"context2","ctx":{"nil":null}`) + require.Contains(t, sdata, `"msg":"context3","ctx":{"func":"unsupported field value: (func())`) + require.Contains(t, sdata, `"msg":"context4","ctx":{"ip":"0.0.0.0"}`) + require.Contains(t, sdata, `"msg":"context5","ctx":{"nil":null}`) + require.Contains(t, sdata, `"msg":"context6","ctx":{"func":"unsupported field value: (func())`) + require.Contains(t, sdata, `"msg":"context7","ctx":{"err":null}`) + require.Contains(t, sdata, `"msg":"context8","ctx":{"err":"er1"}`) +} + +func BenchmarkNew(b *testing.B) { + b.ReportAllocs() + + ll := log.New() + ll.SetOutput(io.Discard) + ll.SetLevel(log.LevelDebug) + wg := iosync.NewGroup() + + b.ResetTimer() + b.RunParallel(func(p *testing.PB) { + wg.Background(func() { + for p.Next() { + ll.WithFields(log.Fields{"a": "b"}).Infof("hello") + ll.WithField("a", "b").Infof("hello") + ll.WithError("a", fmt.Errorf("b")).Infof("hello") + } + }) + }) + wg.Wait() + ll.Close() +} diff --git a/sdk/log/message.go b/sdk/log/message.go new file mode 100644 index 0000000..6867e2f --- /dev/null +++ b/sdk/log/message.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package log + +import "sync" + +//go:generate easyjson + +var poolMessage = sync.Pool{ + New: func() interface{} { + return newMessage() + }, +} + +//easyjson:json +type message struct { + Time int64 `json:"time"` + Level string `json:"lvl"` + Message string `json:"msg"` + Ctx map[string]interface{} `json:"ctx,omitempty"` +} + +func newMessage() *message { + return &message{ + Ctx: make(map[string]interface{}), + } +} + +func (v *message) Reset() { + v.Time = 0 + v.Level = "" + v.Message = "" + for s := range v.Ctx { + delete(v.Ctx, s) + } +} diff --git a/sdk/log/message_easyjson.go b/sdk/log/message_easyjson.go new file mode 100644 index 0000000..ac7b943 --- /dev/null +++ b/sdk/log/message_easyjson.go @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +// Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. + +package log + +import ( + json "encoding/json" + + easyjson "github.com/mailru/easyjson" + jlexer "github.com/mailru/easyjson/jlexer" + jwriter "github.com/mailru/easyjson/jwriter" +) + +// suppress unused package warning +var ( + _ *json.RawMessage + _ *jlexer.Lexer + _ *jwriter.Writer + _ easyjson.Marshaler +) + +func easyjson4086215fDecodeGithubComOsspkgGoSdkLog(in *jlexer.Lexer, out *message) { + isTopLevel := in.IsStart() + if in.IsNull() { + if isTopLevel { + in.Consumed() + } + in.Skip() + return + } + in.Delim('{') + for !in.IsDelim('}') { + key := in.UnsafeFieldName(false) + in.WantColon() + if in.IsNull() { + in.Skip() + in.WantComma() + continue + } + switch key { + case "time": + out.Time = int64(in.Int64()) + case "lvl": + out.Level = string(in.String()) + case "msg": + out.Message = string(in.String()) + case "ctx": + if in.IsNull() { + in.Skip() + } else { + in.Delim('{') + if !in.IsDelim('}') { + out.Ctx = make(map[string]interface{}) + } else { + out.Ctx = nil + } + for !in.IsDelim('}') { + key := string(in.String()) + in.WantColon() + var v1 interface{} + if m, ok := v1.(easyjson.Unmarshaler); ok { + m.UnmarshalEasyJSON(in) + } else if m, ok := v1.(json.Unmarshaler); ok { + _ = m.UnmarshalJSON(in.Raw()) + } else { + v1 = in.Interface() + } + (out.Ctx)[key] = v1 + in.WantComma() + } + in.Delim('}') + } + default: + in.SkipRecursive() + } + in.WantComma() + } + in.Delim('}') + if isTopLevel { + in.Consumed() + } +} +func easyjson4086215fEncodeGithubComOsspkgGoSdkLog(out *jwriter.Writer, in message) { + out.RawByte('{') + first := true + _ = first + { + const prefix string = ",\"time\":" + out.RawString(prefix[1:]) + out.Int64(int64(in.Time)) + } + { + const prefix string = ",\"lvl\":" + out.RawString(prefix) + out.String(string(in.Level)) + } + { + const prefix string = ",\"msg\":" + out.RawString(prefix) + out.String(string(in.Message)) + } + if len(in.Ctx) != 0 { + const prefix string = ",\"ctx\":" + out.RawString(prefix) + { + out.RawByte('{') + v2First := true + for v2Name, v2Value := range in.Ctx { + if v2First { + v2First = false + } else { + out.RawByte(',') + } + out.String(string(v2Name)) + out.RawByte(':') + if m, ok := v2Value.(easyjson.Marshaler); ok { + m.MarshalEasyJSON(out) + } else if m, ok := v2Value.(json.Marshaler); ok { + out.Raw(m.MarshalJSON()) + } else { + out.Raw(json.Marshal(v2Value)) + } + } + out.RawByte('}') + } + } + out.RawByte('}') +} + +// MarshalJSON supports json.Marshaler interface +func (v message) MarshalJSON() ([]byte, error) { + w := jwriter.Writer{} + easyjson4086215fEncodeGithubComOsspkgGoSdkLog(&w, v) + return w.Buffer.BuildBytes(), w.Error +} + +// MarshalEasyJSON supports easyjson.Marshaler interface +func (v message) MarshalEasyJSON(w *jwriter.Writer) { + easyjson4086215fEncodeGithubComOsspkgGoSdkLog(w, v) +} + +// UnmarshalJSON supports json.Unmarshaler interface +func (v *message) UnmarshalJSON(data []byte) error { + r := jlexer.Lexer{Data: data} + easyjson4086215fDecodeGithubComOsspkgGoSdkLog(&r, v) + return r.Error() +} + +// UnmarshalEasyJSON supports easyjson.Unmarshaler interface +func (v *message) UnmarshalEasyJSON(l *jlexer.Lexer) { + easyjson4086215fDecodeGithubComOsspkgGoSdkLog(l, v) +} diff --git a/sdk/netutil/epoll/common.go b/sdk/netutil/epoll/common.go new file mode 100644 index 0000000..0ca6b08 --- /dev/null +++ b/sdk/netutil/epoll/common.go @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package epoll + +import ( + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + errServAlreadyRunning = errors.New("server already running") + errServAlreadyStopped = errors.New("server already stopped") + errEpollEmptyEvents = errors.New("epoll empty event") +) + +var ( + defaultEOF = []byte("\r\n") +) diff --git a/sdk/netutil/epoll/epoll.go b/sdk/netutil/epoll/epoll.go new file mode 100644 index 0000000..62be07c --- /dev/null +++ b/sdk/netutil/epoll/epoll.go @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package epoll + +import ( + "net" + "sync" + "syscall" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil" + "golang.org/x/sys/unix" +) + +const ( + epollEvents = unix.POLLIN | unix.POLLRDHUP | unix.POLLERR | unix.POLLHUP | unix.POLLNVAL + epollEventCount = 100 + epollEventIntervalMS = 500 +) + +type ( + epollEventsSlice []unix.EpollEvent + + epoll struct { + fd int + conn epollNetMap + events epollEventsSlice + nets epollNetSlice + log log.Logger + mux sync.RWMutex + } +) + +func newEpoll(l log.Logger) (*epoll, error) { + fd, err := unix.EpollCreate1(0) + if err != nil { + return nil, err + } + return &epoll{ + fd: fd, + conn: make(epollNetMap), + events: make(epollEventsSlice, epollEventCount), + nets: make(epollNetSlice, epollEventCount), + log: l, + }, nil +} + +func (v *epoll) AddOrClose(c net.Conn) error { + fd := netutil.FileDescriptor(c) + err := unix.EpollCtl(v.fd, syscall.EPOLL_CTL_ADD, fd, &unix.EpollEvent{Events: epollEvents, Fd: int32(fd)}) + if err != nil { + return errors.Wrap(err, c.Close()) + } + v.mux.Lock() + v.conn[fd] = &epollNetItem{Conn: c, Fd: fd} + v.mux.Unlock() + return nil +} + +func (v *epoll) removeFD(fd int) error { + return unix.EpollCtl(v.fd, syscall.EPOLL_CTL_DEL, fd, nil) +} + +func (v *epoll) Close(c *epollNetItem) error { + v.mux.Lock() + defer v.mux.Unlock() + return v.closeConn(c) +} + +func (v *epoll) closeConn(c *epollNetItem) error { + if err := v.removeFD(c.Fd); err != nil { + return err + } + delete(v.conn, c.Fd) + return c.Conn.Close() +} + +func (v *epoll) CloseAll() (err error) { + v.mux.Lock() + defer v.mux.Unlock() + + for _, conn := range v.conn { + if err0 := v.closeConn(conn); err0 != nil { + err = errors.Wrap(err, err0) + } + } + v.conn = make(epollNetMap) + return +} + +func (v *epoll) getConn(fd int) (*epollNetItem, bool) { + v.mux.RLock() + conn, ok := v.conn[fd] + v.mux.RUnlock() + return conn, ok +} + +func (v *epoll) Wait() (epollNetSlice, error) { + n, err := unix.EpollWait(v.fd, v.events, epollEventIntervalMS) + if err != nil { + return nil, err + } + if n <= 0 { + return nil, errEpollEmptyEvents + } + + v.nets = v.nets[:0] + for i := 0; i < n; i++ { + fd := int(v.events[i].Fd) + conn, ok := v.getConn(fd) + if !ok { + if err = v.removeFD(fd); err != nil { + v.log.WithFields(log.Fields{ + "err": err.Error(), + "fd": fd, + }).Errorf("Close fd") + } + continue + } + if conn.IsAwait() { + continue + } + conn.Await(true) + + switch v.events[i].Events { + case unix.POLLIN: + v.nets = append(v.nets, conn) + default: + if err = v.Close(conn); err != nil { + v.log.WithFields(log.Fields{"err": err.Error()}).Errorf("Epoll close connect") + } + } + } + + return v.nets, nil +} diff --git a/sdk/netutil/epoll/epoll_connect.go b/sdk/netutil/epoll/epoll_connect.go new file mode 100644 index 0000000..ffc2474 --- /dev/null +++ b/sdk/netutil/epoll/epoll_connect.go @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package epoll + +import ( + "bytes" + "io" + "sync" + + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + epollBodyPool = sync.Pool{ + New: func() interface{} { + return make([]byte, 0, 1024) + }, + } + + errInvalidPoolType = errors.New("invalid data type from pool") +) + +type Handler func([]byte, io.Writer) error + +func newEpollConn(conn io.ReadWriter, handler Handler, eof []byte) error { + var ( + n int + err error + l = len(eof) + ) + b, ok := epollBodyPool.Get().([]byte) + if !ok { + return errInvalidPoolType + } + defer epollBodyPool.Put(b[:0]) //nolint:staticcheck + + for { + if len(b) == cap(b) { + b = append(b, 0)[:len(b)] + } + n, err = conn.Read(b[len(b):cap(b)]) + b = b[:len(b)+n] + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + if len(b) < l { + return io.EOF + } + if bytes.Equal(eof, b[len(b)-l:]) { + b = b[:len(b)-l] + break + } + } + err = handler(b, conn) + return err +} diff --git a/sdk/netutil/epoll/epoll_net.go b/sdk/netutil/epoll/epoll_net.go new file mode 100644 index 0000000..97c07c2 --- /dev/null +++ b/sdk/netutil/epoll/epoll_net.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package epoll + +import ( + "net" + "sync" +) + +type ( + epollNetMap map[int]*epollNetItem + epollNetSlice []*epollNetItem +) + +type epollNetItem struct { + Conn net.Conn + await bool + Fd int + mux sync.RWMutex +} + +func (v *epollNetItem) Await(b bool) { + v.mux.Lock() + v.await = b + v.mux.Unlock() +} + +func (v *epollNetItem) IsAwait() bool { + v.mux.RLock() + is := v.await + v.mux.RUnlock() + return is +} diff --git a/sdk/netutil/epoll/server.go b/sdk/netutil/epoll/server.go new file mode 100644 index 0000000..7cf9b90 --- /dev/null +++ b/sdk/netutil/epoll/server.go @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package epoll + +import ( + "io" + "net" + "time" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil" + "golang.org/x/sys/unix" +) + +type ( + Config struct { + Addr string `yaml:"addr"` + ReadTimeout time.Duration `yaml:"read_timeout,omitempty"` + WriteTimeout time.Duration `yaml:"write_timeout,omitempty"` + IdleTimeout time.Duration `yaml:"idle_timeout,omitempty"` + ShutdownTimeout time.Duration `yaml:"shutdown_timeout,omitempty"` + } + + Server struct { + sync iosync.Switch + wg iosync.Group + handler Handler + log log.Logger + conf Config + eof []byte + listener net.Listener + epoll *epoll + } +) + +func New(conf Config, handler Handler, eof []byte, l log.Logger) *Server { + return &Server{ + sync: iosync.NewSwitch(), + wg: iosync.NewGroup(), + conf: conf, + handler: handler, + log: l, + eof: eof, + } +} + +func (s *Server) validate() { + s.conf.Addr = netutil.CheckHostPort(s.conf.Addr) + if len(s.eof) == 0 { + s.eof = defaultEOF + } +} + +func (s *Server) Up(ctx app.Context) (err error) { + if !s.sync.On() { + return errServAlreadyRunning + } + s.validate() + if s.listener, err = net.Listen("tcp", s.conf.Addr); err != nil { + return + } + if s.epoll, err = newEpoll(s.log); err != nil { + return + } + s.log.WithFields(log.Fields{"ip": s.conf.Addr}).Infof("Epoll server started") + s.wg.Background(func() { + s.connAccept(ctx) + }) + s.wg.Background(func() { + s.epollAccept(ctx) + }) + return +} + +func (s *Server) Down() error { + if !s.sync.Off() { + return errServAlreadyStopped + } + err := errors.Wrap(s.epoll.CloseAll(), s.listener.Close()) + s.wg.Wait() + if err != nil { + s.log.WithFields(log.Fields{ + "err": err.Error(), + "ip": s.conf.Addr, + }).Errorf("Epoll server stopped") + return err + } + s.log.WithFields(log.Fields{ + "ip": s.conf.Addr, + }).Infof("Epoll server stopped") + return nil +} + +func (s *Server) connAccept(ctx app.Context) { + defer func() { + ctx.Close() + }() + for { + conn, err := s.listener.Accept() + if err != nil { + select { + case <-ctx.Done(): + return + default: + s.log.WithFields(log.Fields{"err": err.Error()}).Errorf("Epoll conn accept") + //TODO: check error? + //var ne net.Error + //if errors.As(err, ne) { + // time.Sleep(1 * time.Second) + // continue + //} + return + } + } + if err = s.epoll.AddOrClose(conn); err != nil { + s.log.WithFields(log.Fields{ + "err": err.Error(), "ip": conn.RemoteAddr().String(), + }).Errorf("Epoll add conn") + } + } +} + +func (s *Server) epollAccept(ctx app.Context) { + defer func() { + ctx.Close() + }() + for { + select { + case <-ctx.Done(): + return + default: + list, err := s.epoll.Wait() + switch true { + case err == nil: + case errors.Is(err, errEpollEmptyEvents): + continue + case errors.Is(err, unix.EINTR): + continue + default: + s.log.WithFields(log.Fields{ + "err": err.Error(), + }).Errorf("Epoll accept conn") + continue + } + + for _, c := range list { + c := c + go func(conn *epollNetItem) { + defer conn.Await(false) + + if err1 := newEpollConn(conn.Conn, s.handler, s.eof); err1 != nil { + if err2 := s.epoll.Close(conn); err2 != nil { + s.log.WithFields(log.Fields{ + "err": err2.Error(), + "ip": conn.Conn.RemoteAddr().String(), + }).Errorf("Epoll add conn") + } + if errors.Is(err1, io.EOF) { + s.log.WithFields(log.Fields{ + "err": err1.Error(), + "ip": conn.Conn.RemoteAddr().String(), + }).Errorf("Epoll bad conn") + } + } + }(c) + } + } + } +} diff --git a/sdk/netutil/netutils.go b/sdk/netutil/netutils.go new file mode 100644 index 0000000..ae8db9e --- /dev/null +++ b/sdk/netutil/netutils.go @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package netutil + +import ( + "net" + "reflect" + "strings" + + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + ErrResolveTCPAddress = errors.New("resolve tcp address") +) + +func RandomPort(host string) (string, error) { + host = strings.Join([]string{host, "0"}, ":") + addr, err := net.ResolveTCPAddr("tcp", host) + if err != nil { + return host, errors.Wrap(err, ErrResolveTCPAddress) + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + return host, errors.Wrap(err, ErrResolveTCPAddress) + } + v := l.Addr().String() + if err = l.Close(); err != nil { + return host, errors.Wrap(err, ErrResolveTCPAddress) + } + return v, nil +} + +func FileDescriptor(c net.Conn) int { + fd := reflect.Indirect(reflect.ValueOf(c)).FieldByName("fd") + pfd := reflect.Indirect(fd).FieldByName("pfd") + return int(pfd.FieldByName("Sysfd").Int()) +} + +func CheckHostPort(addr string) string { + hp := strings.Split(addr, ":") + if len(hp) != 2 { + tmp := make([]string, 2) + copy(hp, tmp) + hp = tmp + } + if len(hp[0]) == 0 { + hp[0] = "0.0.0.0" + } + if len(hp[1]) == 0 { + if v, err := RandomPort(hp[0]); err == nil { + hp[1] = v + } else { + hp[1] = "8080" + } + } + return strings.Join(hp, ":") +} diff --git a/sdk/netutil/unixsocket/client.go b/sdk/netutil/unixsocket/client.go new file mode 100644 index 0000000..9992804 --- /dev/null +++ b/sdk/netutil/unixsocket/client.go @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package unixsocket + +import ( + "net" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/ioutil" +) + +type Client struct { + path string +} + +func NewClient(path string) *Client { + return &Client{ + path: path, + } +} + +func (v *Client) Exec(name string, b []byte) ([]byte, error) { + conn, err := net.Dial("unix", v.path) + if err != nil { + return nil, errors.Wrapf(err, "open connect [unix:%s]", v.path) + } + defer conn.Close() //nolint: errcheck + if err = ioutil.WriteBytes(conn, append([]byte(name+divideStr), b...), newLine); err != nil { + return nil, err + } + return ioutil.ReadBytes(conn, newLine) +} + +func (v *Client) ExecString(name string, b string) ([]byte, error) { + return v.Exec(name, []byte(b)) +} diff --git a/sdk/netutil/unixsocket/common.go b/sdk/netutil/unixsocket/common.go new file mode 100644 index 0000000..66ac29d --- /dev/null +++ b/sdk/netutil/unixsocket/common.go @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package unixsocket + +import ( + "io" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/ioutil" +) + +var ( + newLine = "\n" + divideStr = " " + divideByte = byte(' ') + + ErrInvalidCommand = errors.New("command not found") +) + +func writeError(v io.Writer, err error) error { + return ioutil.WriteBytes(v, []byte(err.Error()), newLine) +} + +func parseCommand(b []byte) (string, []byte) { + for i := 0; i < len(b); i++ { + if b[i] == divideByte { + if len(b) > i+2 { + return string(b[0:i]), b[i+1:] + } + return string(b[0:i]), nil + } + } + return string(b), nil +} diff --git a/sdk/netutil/unixsocket/server.go b/sdk/netutil/unixsocket/server.go new file mode 100644 index 0000000..4dd0885 --- /dev/null +++ b/sdk/netutil/unixsocket/server.go @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package unixsocket + +import ( + "io" + "net" + "os" + "sync" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/ioutil" +) + +var ( + ErrServAlreadyRunning = errors.New("server already running") + ErrServAlreadyStopped = errors.New("server already stopped") +) + +type ( + CommandHandler func([]byte) ([]byte, error) + + Server struct { + status iosync.Switch + path string + socket net.Listener + commands map[string]CommandHandler + mux sync.RWMutex + logError func(err error) + } +) + +func NewServer(path string) *Server { + return &Server{ + path: path, + status: iosync.NewSwitch(), + commands: make(map[string]CommandHandler), + logError: func(_ error) {}, + } +} + +func (v *Server) ErrorLog(handler func(err error)) { + v.mux.Lock() + v.logError = func(err error) { + if err == nil { + return + } + handler(err) + } + v.mux.Unlock() +} + +func (v *Server) AddCommand(name string, handler CommandHandler) { + v.mux.Lock() + v.commands[name] = handler + v.mux.Unlock() +} + +func (v *Server) Down() error { + v.mux.Lock() + defer v.mux.Unlock() + + if !v.status.Off() { + return ErrServAlreadyStopped + } + if v.socket != nil { + return v.socket.Close() + } + return nil +} + +func (v *Server) Up() error { + if !v.status.On() { + return ErrServAlreadyRunning + } + err := os.Remove(v.path) + if err != nil && !os.IsNotExist(err) { + return errors.Wrapf(err, "remove unix socket") + } + if v.socket, err = net.Listen("unix", v.path); err != nil { + return errors.Wrapf(err, "init unix socket") + } + for { + fd, err := v.socket.Accept() + if err != nil { + return err + } + go v.handler(fd) + } +} + +func (v *Server) handler(rwc io.ReadWriteCloser) { + v.mux.RLock() + defer func() { + v.mux.RUnlock() + v.logError(rwc.Close()) + }() + + b, err := ioutil.ReadBytes(rwc, newLine) + if err != nil { + v.logError(errors.Wrapf(err, "read unix socket request")) + v.logError(errors.Wrapf(writeError(rwc, err), "write unix socket error")) + return + } + command, data := parseCommand(b) + handler, ok := v.commands[command] + if !ok { + v.logError(errors.Wrapf(ErrInvalidCommand, command)) + v.logError(errors.Wrapf(writeError(rwc, ErrInvalidCommand), "write unix socket error")) + return + } + + out, err := handler(data) + if err != nil { + v.logError(errors.Wrapf(err, "call command '%s'", command)) + v.logError(errors.Wrapf(writeError(rwc, err), "write unix socket error")) + return + } + v.logError(errors.Wrapf(ioutil.WriteBytes(rwc, out, newLine), "write unix socket response")) +} diff --git a/sdk/netutil/websocket/client.go b/sdk/netutil/websocket/client.go new file mode 100644 index 0000000..f1a1bca --- /dev/null +++ b/sdk/netutil/websocket/client.go @@ -0,0 +1,233 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "context" + "encoding/json" + "net/http" + + ws "github.com/gorilla/websocket" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" +) + +type ( + cli struct { + url string + id string + header http.Header + events map[EventID]ClientHandler + conn *ws.Conn + logger log.Logger + busBuf chan []byte + ctx context.Context + cancel context.CancelFunc + openFunc []func(cid string) + closeFunc []func(cid string) + sync iosync.Switch + mux iosync.Lock + } + + Client interface { + Encode(eid EventID, in interface{}) + ConnectID() string + Header(key, value string) + SetHandler(call ClientHandler, eids ...EventID) + DelHandler(eids ...EventID) + OnClose(cb func(cid string)) + OnOpen(cb func(cid string)) + Close() + DialAndListen() error + } + + ClientOption interface { + Header(key, value string) + } +) + +func NewClient(ctx context.Context, url string, l log.Logger, opts ...func(ClientOption)) Client { + c, cancel := context.WithCancel(ctx) + wcli := &cli{ + url: url, + id: "", + header: make(http.Header), + events: make(map[EventID]ClientHandler, 10), + conn: nil, + logger: l, + busBuf: make(chan []byte, busBufferSize), + ctx: c, + cancel: cancel, + openFunc: make([]func(string), 0, 2), + closeFunc: make([]func(string), 0, 2), + sync: iosync.NewSwitch(), + mux: iosync.NewLock(), + } + for _, opt := range opts { + opt(wcli) + } + return wcli +} + +func (v *cli) ErrLog(cid string, err error, msg string, args ...interface{}) { + if err == nil { + return + } + v.logger.WithFields(log.Fields{"cid": cid, "err": err.Error()}).Errorf(msg, args...) +} + +func (v *cli) ErrLogMessage(cid string, msg string, args ...interface{}) { + v.logger.WithFields(log.Fields{"cid": cid}).Errorf(msg, args...) +} + +func (v *cli) SetHandler(call ClientHandler, eids ...EventID) { + v.mux.Lock(func() { + for _, eid := range eids { + v.events[eid] = call + } + }) +} + +func (v *cli) DelHandler(eids ...EventID) { + v.mux.Lock(func() { + for _, eid := range eids { + delete(v.events, eid) + } + }) +} + +func (v *cli) GetHandler(eid EventID) (h ClientHandler, ok bool) { + v.mux.RLock(func() { + h, ok = v.events[eid] + }) + return +} + +func (v *cli) Header(key, value string) { + v.mux.Lock(func() { + v.header.Set(key, value) + }) +} + +func (v *cli) ConnectID() string { + return v.id +} + +func (v *cli) connect() *ws.Conn { + return v.conn +} + +func (v *cli) cancelFunc() context.CancelFunc { + return v.cancel +} + +func (v *cli) done() <-chan struct{} { + return v.ctx.Done() +} + +func (v *cli) readBus() <-chan []byte { + return v.busBuf +} + +func (v *cli) WriteToBus(b []byte) { + if v.sync.IsOff() { + return + } + if len(b) == 0 { + return + } + select { + case v.busBuf <- b: + default: + v.ErrLogMessage(v.id, "write chan is full") + } +} + +func (v *cli) Encode(eid EventID, in interface{}) { + getEventModel(func(ev *event) { + ev.ID = eid + ev.Encode(in) + b, err := json.Marshal(ev) + if err != nil { + v.ErrLog(v.ConnectID(), err, "[ws] encode message: %d", eid) + return + } + v.WriteToBus(b) + }) +} + +func (v *cli) callHandler(b []byte) { + getEventModel(func(ev *event) { + if err := json.Unmarshal(b, ev); err != nil { + v.ErrLog(v.ConnectID(), err, "[ws] decode message") + return + } + call, ok := v.GetHandler(ev.EventID()) + if !ok { + return + } + call(ev, ev, v) + }) +} + +func (v *cli) OnClose(cb func(cid string)) { + v.mux.Lock(func() { + v.closeFunc = append(v.closeFunc, cb) + }) +} + +func (v *cli) OnOpen(cb func(cid string)) { + v.mux.Lock(func() { + v.openFunc = append(v.openFunc, cb) + }) +} + +func (v *cli) Close() { + if !v.sync.Off() { + return + } + v.cancel() +} + +func (v *cli) DialAndListen() error { + if !v.sync.On() { + return errOneOpenConnect + } + defer v.sync.Off() + + var ( + err error + resp *http.Response + ) + + if v.conn, resp, err = ws.DefaultDialer.DialContext(v.ctx, v.url, v.header); err != nil { + v.ErrLog(v.ConnectID(), err, "open connect [%s]", v.url) + return err + } else { + v.id = resp.Header.Get("Sec-WebSocket-Accept") + } + + defer func() { + if err := resp.Body.Close(); err != nil { + v.ErrLog(v.ConnectID(), err, "close body connect [%s]", v.url) + } + }() + + v.mux.RLock(func() { + for _, fn := range v.openFunc { + fn(v.ConnectID()) + } + }) + setupPingPong(v.connect()) + go pumpWrite(v, v) + pumpRead(v, v) + v.mux.RLock(func() { + for _, fn := range v.closeFunc { + fn(v.ConnectID()) + } + }) + return nil +} diff --git a/sdk/netutil/websocket/common.go b/sdk/netutil/websocket/common.go new file mode 100644 index 0000000..ee3ad8b --- /dev/null +++ b/sdk/netutil/websocket/common.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "net/http" + "time" + + ws "github.com/gorilla/websocket" + "github.com/osspkg/goppy/sdk/errors" +) + +const ( + pongWait = 60 * time.Second + pingPeriod = pongWait / 3 + busBufferSize = 128 +) + +var ( + errOneOpenConnect = errors.New("connection can be started once") + errUnknownEventID = errors.New("unknown event id") +) + +func newUpgrader() ws.Upgrader { + return ws.Upgrader{ + EnableCompression: true, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(_ *http.Request) bool { + return true + }, + } +} + +func setupPingPong(c *ws.Conn) { + c.SetPingHandler(func(_ string) error { + return errors.Wrap( + c.SetReadDeadline(time.Now().Add(pongWait)), + //v.conn.SetWriteDeadline(time.Now().Add(pongWait)), + ) + }) + c.SetPongHandler(func(_ string) error { + return errors.Wrap( + c.SetReadDeadline(time.Now().Add(pongWait)), + //v.conn.SetWriteDeadline(time.Now().Add(pongWait)), + ) + }) +} diff --git a/sdk/netutil/websocket/connect.go b/sdk/netutil/websocket/connect.go new file mode 100644 index 0000000..3345129 --- /dev/null +++ b/sdk/netutil/websocket/connect.go @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/websocket" + context2 "github.com/osspkg/goppy/sdk/context" + "github.com/osspkg/goppy/sdk/iosync" +) + +type ( + actionsApi interface { + ErrLog(cid string, err error, msg string, args ...interface{}) + ErrLogMessage(cid string, msg string, args ...interface{}) + GetHandler(eid EventID) (EventHandler, bool) + } + + Connect struct { + id string + header http.Header + actions actionsApi + conn *websocket.Conn + busBuf chan []byte + ctx context.Context + cancel context.CancelFunc + openFunc []func(cid string) + closeFunc []func(cid string) + sync iosync.Switch + mux iosync.Lock + } +) + +func NewConnect( + id string, head http.Header, + act actionsApi, conn *websocket.Conn, + ctxs ...context.Context, +) *Connect { + ctx, cancel := context2.Combine(ctxs...) + return &Connect{ + id: id, + header: head, + actions: act, + conn: conn, + busBuf: make(chan []byte, busBufferSize), + ctx: ctx, + cancel: cancel, + closeFunc: make([]func(string), 0, 2), + openFunc: make([]func(string), 0, 2), + sync: iosync.NewSwitch(), + mux: iosync.NewLock(), + } +} + +func (v *Connect) ConnectID() string { + return v.id +} + +func (v *Connect) Head(key string) string { + return v.header.Get(key) +} + +func (v *Connect) connect() *websocket.Conn { + return v.conn +} + +func (v *Connect) cancelFunc() context.CancelFunc { + return v.cancel +} + +func (v *Connect) done() <-chan struct{} { + return v.ctx.Done() +} + +func (v *Connect) readBus() <-chan []byte { + return v.busBuf +} + +func (v *Connect) WriteToBus(b []byte) { + if len(b) == 0 { + return + } + select { + case v.busBuf <- b: + default: + v.actions.ErrLogMessage(v.id, "write chan is full") + } +} + +func (v *Connect) Encode(eid EventID, in interface{}) { + getEventModel(func(ev *event) { + ev.ID = eid + ev.Encode(in) + b, err := json.Marshal(ev) + if err != nil { + v.actions.ErrLog(v.ConnectID(), err, "[ws] encode message: %d", eid) + return + } + v.WriteToBus(b) + }) +} + +func (v *Connect) callHandler(b []byte) { + getEventModel(func(ev *event) { + if err := json.Unmarshal(b, ev); err != nil { + v.actions.ErrLog(v.ConnectID(), err, "[ws] decode message") + return + } + call, ok := v.actions.GetHandler(ev.EventID()) + if !ok { + ev.Error(errUnknownEventID) + } else if err := call(ev, ev, v); err != nil { + ev.Error(err) + } + if bb, err := json.Marshal(ev); err != nil { + v.actions.ErrLog(v.ConnectID(), err, "[ws] encode message: %d", ev.EventID()) + } else { + v.WriteToBus(bb) + } + }) +} + +func (v *Connect) OnClose(cb func(cid string)) { + v.mux.Lock(func() { + v.closeFunc = append(v.closeFunc, cb) + }) +} + +func (v *Connect) OnOpen(cb func(cid string)) { + v.mux.Lock(func() { + v.openFunc = append(v.openFunc, cb) + }) +} + +func (v *Connect) Close() { + if !v.sync.Off() { + return + } + v.actions.ErrLog(v.ConnectID(), v.conn.Close(), "close connect") +} + +func (v *Connect) Run() { + if !v.sync.On() { + return + } + v.mux.RLock(func() { + for _, fn := range v.openFunc { + fn(v.ConnectID()) + } + }) + setupPingPong(v.conn) + go pumpWrite(v, v.actions) + pumpRead(v, v.actions) + v.mux.RLock(func() { + for _, fn := range v.closeFunc { + fn(v.ConnectID()) + } + }) +} diff --git a/sdk/netutil/websocket/event.go b/sdk/netutil/websocket/event.go new file mode 100644 index 0000000..9896cba --- /dev/null +++ b/sdk/netutil/websocket/event.go @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +//go:generate easyjson + +import ( + "encoding/json" + "sync" +) + +var ( + poolEvents = sync.Pool{New: func() interface{} { return &event{} }} +) + +type EventID uint16 + +//easyjson:json +type event struct { + ID EventID `json:"e"` + Data json.RawMessage `json:"d"` + Err *string `json:"err,omitempty"` +} + +func getEventModel(call func(ev *event)) { + m, ok := poolEvents.Get().(*event) + if !ok { + m = &event{} + } + call(m) + poolEvents.Put(m.Reset()) +} + +func (v *event) EventID() EventID { + return v.ID +} + +func (v *event) Decode(in interface{}) error { + return json.Unmarshal(v.Data, in) +} + +func (v *event) Encode(in interface{}) { + b, err := json.Marshal(in) + if err != nil { + v.Error(err) + return + } + v.Body(b) +} + +func (v *event) EncodeEvent(id EventID, in interface{}) { + v.ID = id + v.Encode(in) +} + +func (v *event) Reset() *event { + v.ID, v.Err, v.Data = 0, nil, v.Data[:0] + return v +} + +func (v *event) Error(e error) { + if e == nil { + return + } + err := e.Error() + v.Err, v.Data = &err, v.Data[:0] +} + +func (v *event) Body(b []byte) { + v.Err, v.Data = nil, append(v.Data[:0], b...) +} diff --git a/plugins/web/ws_event_easyjson.go b/sdk/netutil/websocket/event_easyjson.go similarity index 71% rename from plugins/web/ws_event_easyjson.go rename to sdk/netutil/websocket/event_easyjson.go index b8e32bb..aac6095 100644 --- a/plugins/web/ws_event_easyjson.go +++ b/sdk/netutil/websocket/event_easyjson.go @@ -1,6 +1,11 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + // Code generated by easyjson for marshaling/unmarshaling. DO NOT EDIT. -package web +package websocket import ( json "encoding/json" @@ -17,7 +22,7 @@ var ( _ easyjson.Marshaler ) -func easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(in *jlexer.Lexer, out *event) { +func easyjsonF642ad3eDecodeGithubComOsspkgGoppySdkNetutilWebsocket(in *jlexer.Lexer, out *event) { isTopLevel := in.IsStart() if in.IsNull() { if isTopLevel { @@ -37,7 +42,7 @@ func easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(in *jlexer.Lexer, out } switch key { case "e": - out.ID = uint(in.Uint()) + out.ID = EventID(in.Uint16()) case "d": if data := in.Raw(); in.Ok() { in.AddError((out.Data).UnmarshalJSON(data)) @@ -52,10 +57,6 @@ func easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(in *jlexer.Lexer, out } *out.Err = string(in.String()) } - case "u": - if data := in.Raw(); in.Ok() { - in.AddError((out.UID).UnmarshalJSON(data)) - } default: in.SkipRecursive() } @@ -66,14 +67,14 @@ func easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(in *jlexer.Lexer, out in.Consumed() } } -func easyjson3146fba7EncodeGithubComOsspkgGoppyPluginsWeb(out *jwriter.Writer, in event) { +func easyjsonF642ad3eEncodeGithubComOsspkgGoppySdkNetutilWebsocket(out *jwriter.Writer, in event) { out.RawByte('{') first := true _ = first { const prefix string = ",\"e\":" out.RawString(prefix[1:]) - out.Uint(uint(in.ID)) + out.Uint16(uint16(in.ID)) } { const prefix string = ",\"d\":" @@ -85,34 +86,29 @@ func easyjson3146fba7EncodeGithubComOsspkgGoppyPluginsWeb(out *jwriter.Writer, i out.RawString(prefix) out.String(string(*in.Err)) } - if len(in.UID) != 0 { - const prefix string = ",\"u\":" - out.RawString(prefix) - out.Raw((in.UID).MarshalJSON()) - } out.RawByte('}') } // MarshalJSON supports json.Marshaler interface func (v event) MarshalJSON() ([]byte, error) { w := jwriter.Writer{} - easyjson3146fba7EncodeGithubComOsspkgGoppyPluginsWeb(&w, v) + easyjsonF642ad3eEncodeGithubComOsspkgGoppySdkNetutilWebsocket(&w, v) return w.Buffer.BuildBytes(), w.Error } // MarshalEasyJSON supports easyjson.Marshaler interface func (v event) MarshalEasyJSON(w *jwriter.Writer) { - easyjson3146fba7EncodeGithubComOsspkgGoppyPluginsWeb(w, v) + easyjsonF642ad3eEncodeGithubComOsspkgGoppySdkNetutilWebsocket(w, v) } // UnmarshalJSON supports json.Unmarshaler interface func (v *event) UnmarshalJSON(data []byte) error { r := jlexer.Lexer{Data: data} - easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(&r, v) + easyjsonF642ad3eDecodeGithubComOsspkgGoppySdkNetutilWebsocket(&r, v) return r.Error() } // UnmarshalEasyJSON supports easyjson.Unmarshaler interface func (v *event) UnmarshalEasyJSON(l *jlexer.Lexer) { - easyjson3146fba7DecodeGithubComOsspkgGoppyPluginsWeb(l, v) + easyjsonF642ad3eDecodeGithubComOsspkgGoppySdkNetutilWebsocket(l, v) } diff --git a/plugins/web/ws_event_test.go b/sdk/netutil/websocket/event_test.go similarity index 91% rename from plugins/web/ws_event_test.go rename to sdk/netutil/websocket/event_test.go index 09143a7..412b6e4 100644 --- a/plugins/web/ws_event_test.go +++ b/sdk/netutil/websocket/event_test.go @@ -3,7 +3,7 @@ * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. */ -package web +package websocket import ( "encoding/json" @@ -20,13 +20,13 @@ func TestUnit_Event(t *testing.T) { b, err := json.Marshal(ev) require.NoError(t, err) - require.Equal(t, string(b), "{\"e\":1001,\"d\":{\"token\":\"12345\",\"os\":\"debian\"},\"u\":\"1111\"}") + require.Equal(t, string(b), "{\"e\":1001,\"d\":{\"token\":\"12345\",\"os\":\"debian\"}}") ev.Error(fmt.Errorf("error1")) b, err = json.Marshal(ev) require.NoError(t, err) - require.Equal(t, string(b), "{\"e\":1001,\"d\":null,\"err\":\"error1\",\"u\":\"1111\"}") + require.Equal(t, string(b), "{\"e\":1001,\"d\":null,\"err\":\"error1\"}") ev.Reset() diff --git a/sdk/netutil/websocket/handler.go b/sdk/netutil/websocket/handler.go new file mode 100644 index 0000000..d4859b7 --- /dev/null +++ b/sdk/netutil/websocket/handler.go @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +type ( + ClientHandler func(w CRequest, r CResponse, m CMeta) + + CResponse interface { + EventID() EventID + Decode(in interface{}) error + } + + CRequest interface { + Encode(in interface{}) + EncodeEvent(id EventID, in interface{}) + } + + CMeta interface { + ConnectID() string + } +) + +type ( + EventHandler func(w Response, r Request, m Meta) error + + Request interface { + EventID() EventID + Decode(in interface{}) error + } + + Response interface { + Encode(in interface{}) + EncodeEvent(id EventID, in interface{}) + } + + Meta interface { + ConnectID() string + Head(key string) string + OnClose(cb func(cid string)) + OnOpen(cb func(cid string)) + } +) diff --git a/sdk/netutil/websocket/observable.go b/sdk/netutil/websocket/observable.go new file mode 100644 index 0000000..50c9a5b --- /dev/null +++ b/sdk/netutil/websocket/observable.go @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "context" + "sync/atomic" + "time" + + "github.com/osspkg/goppy/sdk/iosync" +) + +type ( + Observable interface { + Subscribe(eid EventID, in interface{}) Subscription + } + + ObservableClient interface { + OnClose(cb func(cid string)) + Encode(eid EventID, in interface{}) + DelHandler(eids ...EventID) + SetHandler(call ClientHandler, eids ...EventID) + } + + _obs struct { + cli ObservableClient + } +) + +func NewObservable(cli ObservableClient) Observable { + return &_obs{ + cli: cli, + } +} + +func (v *_obs) Subscribe(eid EventID, in interface{}) Subscription { + ctx, cncl := context.WithCancel(context.TODO()) + sub := &_sub{ + eid: eid, + count: 0, + ctx: ctx, + cncl: cncl, + cli: v.cli, + call: func() { + if in == nil { + return + } + v.cli.Encode(eid, in) + }, + sync: iosync.NewSwitch(), + } + v.cli.OnClose(func(_ string) { + sub.Unsubscribe() + }) + return sub +} + +type ( + cliApi interface { + DelHandler(eids ...EventID) + SetHandler(call ClientHandler, eids ...EventID) + } + _sub struct { + eid EventID + count uint64 + maxCount uint64 + ctx context.Context + cncl context.CancelFunc + cli cliApi + call func() + sync iosync.Switch + } + + Subscription interface { + Listen(call func(ListenArg), pipe ...PipeFunc) + Unsubscribe() + } + + ListenArg interface { + Decode(in interface{}) error + } +) + +func (v *_sub) Listen(call func(ListenArg), pipe ...PipeFunc) { + if !v.sync.On() { + return + } + for _, fn := range pipe { + v.ctx = fn(v.ctx) + } + if tc, ok := v.ctx.Value(pipeTakeKey).(uint64); ok { + v.maxCount = tc + } + v.cli.SetHandler(func() func(w CRequest, r CResponse, m CMeta) { + return func(_ CRequest, r CResponse, _ CMeta) { + atomic.AddUint64(&v.count, 1) + call(r) + if v.maxCount > 0 && atomic.LoadUint64(&v.count) >= v.maxCount { + v.Unsubscribe() + } + } + }(), v.eid) + v.call() + <-v.ctx.Done() + v.cli.DelHandler(v.eid) +} + +func (v *_sub) Unsubscribe() { + v.cncl() +} + +type ( + PipeFunc func(ctx context.Context) context.Context + pipeKey string +) + +var ( + pipeTakeKey pipeKey = "take" +) + +func PipeTimeout(t time.Duration) PipeFunc { + return func(ctx context.Context) context.Context { + c, _ := context.WithTimeout(ctx, t) //nolint: govet + return c + } +} + +func PipeTake(count uint64) PipeFunc { + return func(ctx context.Context) context.Context { + return context.WithValue(ctx, pipeTakeKey, count) + } +} diff --git a/sdk/netutil/websocket/pump.go b/sdk/netutil/websocket/pump.go new file mode 100644 index 0000000..782e007 --- /dev/null +++ b/sdk/netutil/websocket/pump.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "context" + "strings" + "time" + + ws "github.com/gorilla/websocket" + "github.com/osspkg/goppy/sdk/errors" +) + +type ( + pumpApi interface { + ConnectID() string + callHandler(b []byte) + readBus() <-chan []byte + connect() *ws.Conn + cancelFunc() context.CancelFunc + done() <-chan struct{} + Close() + } + pumpActionsApi interface { + ErrLog(cid string, err error, msg string, args ...interface{}) + } +) + +func isClosingError(err error) bool { + if ws.IsCloseError(err, ws.CloseNormalClosure, ws.CloseGoingAway, ws.CloseNoStatusReceived) || + strings.Contains(err.Error(), "use of closed network connection") || + errors.Is(err, ws.ErrCloseSent) { + return true + } + return false +} + +func pumpRead(p pumpApi, a pumpActionsApi) { + defer p.cancelFunc() + for { + _, message, err := p.connect().ReadMessage() + if err != nil { + if !isClosingError(err) { + a.ErrLog(p.ConnectID(), err, "[ws] read message") + } + return + } + go p.callHandler(message) + } +} + +func pumpWrite(p pumpApi, a pumpActionsApi) { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + a.ErrLog(p.ConnectID(), p.connect().Close(), "close connect") + }() + for { + select { + case <-p.done(): + err := p.connect().WriteControl(ws.CloseMessage, + ws.FormatCloseMessage(ws.CloseNormalClosure, "Bye bye!"), time.Now().Add(pongWait)) + if err != nil && !isClosingError(err) { + a.ErrLog(p.ConnectID(), err, "[ws] send close") + } + return + case <-ticker.C: + if err := p.connect().WriteControl(ws.PingMessage, nil, time.Now().Add(pongWait)); err != nil { + if !isClosingError(err) { + a.ErrLog(p.ConnectID(), err, "[ws] send ping") + } + return + } + case m := <-p.readBus(): + if err := p.connect().WriteMessage(ws.TextMessage, m); err != nil { + if !isClosingError(err) { + a.ErrLog(p.ConnectID(), err, "[ws] send message") + } + return + } + } + } +} diff --git a/sdk/netutil/websocket/server.go b/sdk/netutil/websocket/server.go new file mode 100644 index 0000000..11e46fc --- /dev/null +++ b/sdk/netutil/websocket/server.go @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package websocket + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/websocket" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" +) + +type Server struct { + clients map[string]*Connect + events map[EventID]EventHandler + upgrade websocket.Upgrader + logger log.Logger + ctx context.Context + cancel context.CancelFunc + mux iosync.Lock + wg iosync.Group +} + +func NewServer(l log.Logger, ctx context.Context, opts ...func(u websocket.Upgrader)) *Server { + up := newUpgrader() + c, cancel := context.WithCancel(ctx) + for _, opt := range opts { + opt(up) + } + return &Server{ + clients: make(map[string]*Connect, 100), + events: make(map[EventID]EventHandler, 10), + upgrade: up, + logger: l, + ctx: c, + cancel: cancel, + mux: iosync.NewLock(), + wg: iosync.NewGroup(), + } +} + +func (v *Server) CloseAll() { + v.cancel() + v.wg.Wait() +} + +func (v *Server) ErrLog(cid string, err error, msg string, args ...interface{}) { + if err == nil { + return + } + v.logger.WithFields(log.Fields{"cid": cid, "err": err.Error()}).Errorf(msg, args...) +} + +func (v *Server) ErrLogMessage(cid string, msg string, args ...interface{}) { + v.logger.WithFields(log.Fields{"cid": cid}).Errorf(msg, args...) +} + +func (v *Server) CountConn() (cc int) { + v.mux.Lock(func() { + cc = len(v.clients) + }) + return +} + +func (v *Server) AddConn(c *Connect) { + v.mux.Lock(func() { + v.clients[c.ConnectID()] = c + }) +} + +func (v *Server) DelConn(id string) { + v.mux.Lock(func() { + delete(v.clients, id) + }) +} + +func (v *Server) SetHandler(call EventHandler, eids ...EventID) { + v.mux.Lock(func() { + for _, eid := range eids { + v.events[eid] = call + } + }) +} + +func (v *Server) GetHandler(eid EventID) (h EventHandler, ok bool) { + v.mux.RLock(func() { + h, ok = v.events[eid] + }) + return +} + +func (v *Server) Broadcast(eid EventID, m interface{}) { + getEventModel(func(ev *event) { + ev.ID = eid + b, err := json.Marshal(m) + if err != nil { + v.ErrLog("*", err, "[ws] broadcast error") + return + } + ev.Body(b) + + b, err = json.Marshal(ev) + if err != nil { + v.ErrLog("*", err, "[ws] broadcast error") + return + } + v.mux.RLock(func() { + for _, c := range v.clients { + c.WriteToBus(b) + } + }) + }) +} + +func (v *Server) SendEvent(eid EventID, m interface{}, cids ...string) { + getEventModel(func(ev *event) { + ev.ID = eid + b, err := json.Marshal(m) + if err != nil { + v.ErrLog("*", err, "[ws] send event error") + return + } + ev.Body(b) + b, err = json.Marshal(ev) + if err != nil { + v.ErrLog("*", err, "[ws] send event error") + return + } + v.mux.RLock(func() { + for _, cid := range cids { + if c, ok := v.clients[cid]; ok { + c.WriteToBus(b) + } + } + }) + }) +} + +func (v *Server) Handling(w http.ResponseWriter, r *http.Request) { + v.wg.Run(func() { + cid := r.Header.Get("Sec-Websocket-Key") + up, err := v.upgrade.Upgrade(w, r, nil) + if err != nil { + v.ErrLog(cid, err, "[ws] upgrade") + w.WriteHeader(http.StatusBadRequest) + return + } + c := NewConnect(cid, r.Header, v, up, r.Context(), v.ctx) + c.OnClose(func(cid string) { + v.DelConn(cid) + }) + c.OnOpen(func(string) { + v.AddConn(c) + }) + c.Run() + }) +} diff --git a/sdk/orm/db.go b/sdk/orm/db.go new file mode 100644 index 0000000..f4a1d25 --- /dev/null +++ b/sdk/orm/db.go @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/orm/plugins" + "github.com/osspkg/goppy/sdk/orm/schema" +) + +type ( + //_db connection storage + _db struct { + conn schema.Connector + opts *options + } + + Database interface { + Pool(name string) Stmt + Dialect() string + } + + options struct { + Logger log.Logger + Metrics plugins.MetricExecutor + } + + PluginSetup func(o *options) +) + +func UsePluginLogger(l log.Logger) PluginSetup { + return func(o *options) { + o.Logger = l + } +} + +func UsePluginMetric(m plugins.MetricExecutor) PluginSetup { + return func(o *options) { + o.Metrics = m + } +} + +// New init database connections +func New(c schema.Connector, opts ...PluginSetup) Database { + o := &options{ + Logger: plugins.DevNullLog, + Metrics: plugins.DevNullMetric, + } + + for _, opt := range opts { + opt(o) + } + + return &_db{ + conn: c, + opts: o, + } +} + +// Pool getting pool connections by name +func (v *_db) Pool(name string) Stmt { + return newStmt(name, v.conn, v.opts) +} + +func (v *_db) Dialect() string { + return v.conn.Dialect() +} diff --git a/sdk/orm/plugins/devnull.go b/sdk/orm/plugins/devnull.go new file mode 100644 index 0000000..c612914 --- /dev/null +++ b/sdk/orm/plugins/devnull.go @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "io" + + "github.com/osspkg/goppy/sdk/log" +) + +var ( + DevNullLog log.Logger = &devNullLogger{} + DevNullMetric MetricExecutor = new(devNullMetric) +) + +type devNullMetric struct{} + +func (devNullMetric) ExecutionTime(_ string, call func()) { call() } + +type devNullLogger struct{} + +func (devNullLogger) SetOutput(io.Writer) {} +func (devNullLogger) Fatalf(string, ...interface{}) {} +func (devNullLogger) Errorf(string, ...interface{}) {} +func (devNullLogger) Warnf(string, ...interface{}) {} +func (devNullLogger) Infof(string, ...interface{}) {} +func (devNullLogger) Debugf(string, ...interface{}) {} +func (devNullLogger) SetLevel(v uint32) {} +func (devNullLogger) Close() {} +func (devNullLogger) GetLevel() uint32 { return 0 } +func (v devNullLogger) WithFields(_ log.Fields) log.Writer { return v } +func (v devNullLogger) WithField(_ string, _ interface{}) log.Writer { return v } +func (v devNullLogger) WithError(_ string, _ error) log.Writer { return v } diff --git a/sdk/orm/plugins/logger.go b/sdk/orm/plugins/logger.go new file mode 100644 index 0000000..2364213 --- /dev/null +++ b/sdk/orm/plugins/logger.go @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "github.com/osspkg/goppy/sdk/log" +) + +var ( + //StdOutLog simple stdout debug log + StdOutLog = func() log.Logger { + l := log.Default() + l.SetLevel(log.LevelDebug) + l.SetOutput(StdOutWriter) + return l + }() +) diff --git a/sdk/orm/plugins/metric.go b/sdk/orm/plugins/metric.go new file mode 100644 index 0000000..e02c808 --- /dev/null +++ b/sdk/orm/plugins/metric.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "time" +) + +type ( + metric struct { + metrics MetricWriter + } + //MetricExecutor interface + MetricExecutor interface { + ExecutionTime(name string, call func()) + } + //MetricWriter interface + MetricWriter interface { + Metric(name string, time time.Duration) + } +) + +// StdOutMetric simple stdout metrig writer +var StdOutMetric = NewMetric(StdOutWriter) + +// NewMetric init new metric +func NewMetric(m MetricWriter) MetricExecutor { + return &metric{metrics: m} +} + +// ExecutionTime calculating the execution time +func (m *metric) ExecutionTime(name string, call func()) { + if m.metrics == nil { + call() + return + } + + t := time.Now() + call() + m.metrics.Metric(name, time.Since(t)) +} diff --git a/sdk/orm/plugins/metric_test.go b/sdk/orm/plugins/metric_test.go new file mode 100644 index 0000000..79a44c9 --- /dev/null +++ b/sdk/orm/plugins/metric_test.go @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewMetric(t *testing.T) { + w := &bytes.Buffer{} + tl := &stdout{Writer: w} + + demo1 := NewMetric(nil) + demo1.ExecutionTime("hello1", func() {}) + + demo2 := NewMetric(tl) + demo2.ExecutionTime("hello2", func() {}) + + result := w.String() + require.NotContains(t, result, "hello1") + require.Contains(t, result, "hello2") +} diff --git a/sdk/orm/plugins/stdout.go b/sdk/orm/plugins/stdout.go new file mode 100644 index 0000000..40d44d5 --- /dev/null +++ b/sdk/orm/plugins/stdout.go @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "fmt" + "io" + "os" + "time" +) + +type stdout struct { + Writer io.Writer +} + +// StdOutWriter simple stdout writer +var StdOutWriter = &stdout{Writer: os.Stdout} + +func (s *stdout) currentTime() string { + return time.Now().Format(time.RFC3339) +} + +// Write metric +func (s *stdout) Write(p []byte) (n int, err error) { + return s.Writer.Write(p) +} + +// Metric write metric to log +func (s *stdout) Metric(name string, t time.Duration) { + fmt.Fprintf(s, "[MTRC] %s - %s: %s\n", s.currentTime(), name, t) //nolint:errcheck +} diff --git a/sdk/orm/plugins/stdout_test.go b/sdk/orm/plugins/stdout_test.go new file mode 100644 index 0000000..ba602c4 --- /dev/null +++ b/sdk/orm/plugins/stdout_test.go @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package plugins + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStdOut(t *testing.T) { + w := &bytes.Buffer{} + + tl := &stdout{Writer: w} + + _, err := tl.Write([]byte("h4gbffke9")) + require.NoError(t, err) + tl.Metric("15gh7netd8", time.Minute) + require.NoError(t, err) + + result := w.String() + require.Contains(t, result, "h4gbffke9") + require.Contains(t, result, "15gh7netd8: 1m0s") +} diff --git a/sdk/orm/schema/common.go b/sdk/orm/schema/common.go new file mode 100644 index 0000000..fcf7d9c --- /dev/null +++ b/sdk/orm/schema/common.go @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package schema + +import ( + "database/sql" + "time" + + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + ErrPoolNotFound = errors.New("pool not found") +) + +const ( + MySQLDialect = "mysql" + SQLiteDialect = "sqlite" + PgSQLDialect = "pgsql" +) + +type ( + //ConfigInterface interface of configs + ConfigInterface interface { + List() []ItemInterface + } + //ItemInterface config item interface + ItemInterface interface { + GetName() string + GetDSN() string + Setup(SetupInterface) + } + //SetupInterface connections setup interface + SetupInterface interface { + SetMaxIdleConns(int) + SetMaxOpenConns(int) + SetConnMaxLifetime(time.Duration) + } + //Connector interface of connection + Connector interface { + Dialect() string + Pool(string) (*sql.DB, error) + Reconnect() error + Close() error + } +) diff --git a/sdk/orm/schema/mysql/mysql.go b/sdk/orm/schema/mysql/mysql.go new file mode 100644 index 0000000..5cc5cb4 --- /dev/null +++ b/sdk/orm/schema/mysql/mysql.go @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package mysql + +import ( + "database/sql" + "fmt" + "net/url" + "sync" + "time" + + _ "github.com/go-sql-driver/mysql" //nolint: golint + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/orm/schema" +) + +const ( + defaultTimeout = time.Second * 5 + defaultTimeoutConn = time.Second * 60 +) + +var ( + _ schema.Connector = (*pool)(nil) + _ schema.ConfigInterface = (*Config)(nil) +) + +type ( + //Config pool of configs + Config struct { + Pool []Item `yaml:"mysql"` + } + + //Item config model + Item struct { + Name string `yaml:"name"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Schema string `yaml:"schema"` + User string `yaml:"user"` + Password string `yaml:"password"` + Timezone string `yaml:"timezone"` + TxIsolationLevel string `yaml:"txisolevel"` + Charset string `yaml:"charset"` + Collation string `yaml:"collation"` + MaxIdleConn int `yaml:"maxidleconn"` + MaxOpenConn int `yaml:"maxopenconn"` + InterpolateParams bool `yaml:"interpolateparams"` + MaxConnTTL time.Duration `yaml:"maxconnttl"` + Timeout time.Duration `yaml:"timeout"` + ReadTimeout time.Duration `yaml:"readtimeout"` + WriteTimeout time.Duration `yaml:"writetimeout"` + OtherParams string `yaml:"other_params"` + } + + pool struct { + conf schema.ConfigInterface + db map[string]*sql.DB + l sync.RWMutex + } +) + +// List getting all configs +func (c *Config) List() (list []schema.ItemInterface) { + for _, item := range c.Pool { + list = append(list, item) + } + return +} + +// GetName getting config name +func (i Item) GetName() string { + return i.Name +} + +// Setup setting config conntections params +func (i Item) Setup(s schema.SetupInterface) { + s.SetMaxIdleConns(i.MaxIdleConn) + s.SetMaxOpenConns(i.MaxOpenConn) + s.SetConnMaxLifetime(i.MaxConnTTL) +} + +// GetDSN connection params +func (i Item) GetDSN() string { + params, err := url.ParseQuery(i.OtherParams) + if err != nil { + params = url.Values{} + } + + params.Add("autocommit", "true") + params.Add("interpolateParams", fmt.Sprintf("%t", i.InterpolateParams)) + + //--- + if len(i.Charset) == 0 { + i.Charset = "utf8mb4" + } + params.Add("charset", i.Charset) + //--- + if len(i.Collation) == 0 { + i.Collation = "utf8mb4_unicode_ci" + } + params.Add("collation", i.Collation) + //--- + if i.Timeout == 0 { + i.Timeout = defaultTimeoutConn + } + params.Add("timeout", i.Timeout.String()) + //--- + if i.ReadTimeout == 0 { + i.ReadTimeout = defaultTimeout + } + params.Add("readTimeout", i.ReadTimeout.String()) + //--- + if i.WriteTimeout == 0 { + i.WriteTimeout = defaultTimeout + } + params.Add("writeTimeout", i.WriteTimeout.String()) + //--- + if len(i.TxIsolationLevel) > 0 { + params.Add("transaction_isolation", i.TxIsolationLevel) + } + //--- + if len(i.Timezone) == 0 { + i.Timezone = "UTC" + } + params.Add("loc", i.Timezone) + //--- + + //--- + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", i.User, i.Password, i.Host, i.Port, i.Schema, params.Encode()) +} + +// New init new mysql connection +func New(conf schema.ConfigInterface) schema.Connector { + c := &pool{ + conf: conf, + db: make(map[string]*sql.DB), + } + + return c +} + +// Dialect getting sql dialect +func (p *pool) Dialect() string { + return schema.MySQLDialect +} + +// Reconnect update connection to database +func (p *pool) Reconnect() error { + if err := p.Close(); err != nil { + return err + } + + p.l.Lock() + defer p.l.Unlock() + + for _, item := range p.conf.List() { + db, err := sql.Open("mysql", item.GetDSN()) + if err != nil { + if er := p.Close(); er != nil { + return errors.Wrap(err, er) + } + return err + } + item.Setup(db) + p.db[item.GetName()] = db + } + return nil +} + +// Close closing connection +func (p *pool) Close() error { + p.l.Lock() + defer p.l.Unlock() + + if len(p.db) > 0 { + for name, db := range p.db { + if err := db.Close(); err != nil { + return err + } + delete(p.db, name) + } + } + return nil +} + +// Pool getting connection pool by name +func (p *pool) Pool(name string) (*sql.DB, error) { + p.l.RLock() + defer p.l.RUnlock() + + db, ok := p.db[name] + if !ok { + return nil, schema.ErrPoolNotFound + } + return db, db.Ping() +} diff --git a/sdk/orm/schema/postgresql/postgresql.go b/sdk/orm/schema/postgresql/postgresql.go new file mode 100644 index 0000000..3b6e4c7 --- /dev/null +++ b/sdk/orm/schema/postgresql/postgresql.go @@ -0,0 +1,179 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package postgresql + +import ( + "database/sql" + "fmt" + "net/url" + "sync" + "time" + + _ "github.com/lib/pq" //nolint: golint + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/orm/schema" +) + +const ( + defaultTimeout = time.Second * 5 + defaultTimeoutConn = time.Second * 60 +) + +var ( + _ schema.Connector = (*pool)(nil) + _ schema.ConfigInterface = (*Config)(nil) +) + +type ( + //Config pool of configs + Config struct { + Pool []Item `yaml:"postgresql"` + } + + //Item config model + Item struct { + Name string `yaml:"name"` + Host string `yaml:"host"` + Port int `yaml:"port"` + Schema string `yaml:"schema"` + User string `yaml:"user"` + Password string `yaml:"password"` + SSLMode bool `yaml:"sslmode"` + AppName string `yaml:"app_name"` + Charset string `yaml:"charset"` + MaxIdleConn int `yaml:"maxidleconn"` + MaxOpenConn int `yaml:"maxopenconn"` + MaxConnTTL time.Duration `yaml:"maxconnttl"` + Timeout time.Duration `yaml:"timeout"` + OtherParams string `yaml:"other_params"` + } + + pool struct { + conf schema.ConfigInterface + db map[string]*sql.DB + l sync.RWMutex + } +) + +// List getting all configs +func (c *Config) List() (list []schema.ItemInterface) { + for _, item := range c.Pool { + list = append(list, item) + } + return +} + +// GetName getting config name +func (i Item) GetName() string { + return i.Name +} + +// Setup setting config conntections params +func (i Item) Setup(s schema.SetupInterface) { + s.SetMaxIdleConns(i.MaxIdleConn) + s.SetMaxOpenConns(i.MaxOpenConn) + s.SetConnMaxLifetime(i.MaxConnTTL) +} + +// GetDSN connection params +func (i Item) GetDSN() string { + params, err := url.ParseQuery(i.OtherParams) + if err != nil { + params = url.Values{} + } + + //--- + if len(i.Charset) == 0 { + i.Charset = "UTF8" + } + params.Add("client_encoding", i.Charset) + //--- + if i.SSLMode { + params.Add("sslmode", "prefer") + } else { + params.Add("sslmode", "disable") + } + //--- + if i.Timeout == 0 { + i.Timeout = defaultTimeoutConn + } + params.Add("connect_timeout", fmt.Sprintf("%.0f", i.Timeout.Seconds())) + //--- + if len(i.AppName) == 0 { + i.AppName = "go_app" + } + params.Add("application_name", i.AppName) + //--- + + //--- + return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?%s", i.User, i.Password, i.Host, i.Port, i.Schema, params.Encode()) +} + +// New init new mysql connection +func New(conf schema.ConfigInterface) schema.Connector { + c := &pool{ + conf: conf, + db: make(map[string]*sql.DB), + } + + return c +} + +// Dialect getting sql dialect +func (p *pool) Dialect() string { + return schema.PgSQLDialect +} + +// Reconnect update connection to database +func (p *pool) Reconnect() error { + if err := p.Close(); err != nil { + return err + } + + p.l.Lock() + defer p.l.Unlock() + + for _, item := range p.conf.List() { + db, err := sql.Open("postgres", item.GetDSN()) + if err != nil { + if er := p.Close(); er != nil { + return errors.Wrap(err, er) + } + return err + } + item.Setup(db) + p.db[item.GetName()] = db + } + return nil +} + +// Close closing connection +func (p *pool) Close() error { + p.l.Lock() + defer p.l.Unlock() + + if len(p.db) > 0 { + for name, db := range p.db { + if err := db.Close(); err != nil { + return err + } + delete(p.db, name) + } + } + return nil +} + +// Pool getting connection pool by name +func (p *pool) Pool(name string) (*sql.DB, error) { + p.l.RLock() + defer p.l.RUnlock() + + db, ok := p.db[name] + if !ok { + return nil, schema.ErrPoolNotFound + } + return db, db.Ping() +} diff --git a/sdk/orm/schema/sqlite/sqlite.go b/sdk/orm/schema/sqlite/sqlite.go new file mode 100644 index 0000000..1ec5013 --- /dev/null +++ b/sdk/orm/schema/sqlite/sqlite.go @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package sqlite + +import ( + "database/sql" + "fmt" + "net/url" + "sync" + + _ "github.com/mattn/go-sqlite3" //nolint: golint + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/orm/schema" +) + +var ( + _ schema.Connector = (*pool)(nil) + _ schema.ConfigInterface = (*Config)(nil) +) + +type ( + //Config pool of configs + Config struct { + Pool []Item `yaml:"sqlite"` + } + + //Item config model + Item struct { + Name string `yaml:"name"` + File string `yaml:"file"` + Cache string `yaml:"cache"` + Mode string `yaml:"mode"` + Journal string `yaml:"journal"` + LockingMode string `yaml:"locking_mode"` + OtherParams string `yaml:"other_params"` + } + + pool struct { + conf schema.ConfigInterface + db map[string]*sql.DB + l sync.RWMutex + } +) + +// List getting all configs +func (c *Config) List() (list []schema.ItemInterface) { + for _, item := range c.Pool { + list = append(list, item) + } + return +} + +// GetName getting config name +func (i Item) GetName() string { return i.Name } + +// GetDSN connection params +func (i Item) GetDSN() string { + params, err := url.ParseQuery(i.OtherParams) + if err != nil { + params = url.Values{} + } + //--- + if len(i.Cache) == 0 { + i.Cache = "private" + } + params.Add("cache", i.Cache) + //--- + if len(i.Mode) == 0 { + i.Mode = "rwc" + } + params.Add("mode", i.Mode) + //--- + if len(i.Journal) == 0 { + i.Journal = "TRUNCATE" + } + params.Add("_journal", i.Journal) + //--- + if len(i.LockingMode) == 0 { + i.LockingMode = "EXCLUSIVE" + } + params.Add("_locking_mode", i.LockingMode) + //-- + return fmt.Sprintf("file:%s?%s", i.File, params.Encode()) +} + +// Setup setting config conntections params +func (i Item) Setup(_ schema.SetupInterface) {} + +// New init new sqlite connection +func New(conf schema.ConfigInterface) schema.Connector { + c := &pool{ + conf: conf, + db: make(map[string]*sql.DB), + } + + return c +} + +// Dialect getting sql dialect +func (p *pool) Dialect() string { + return schema.SQLiteDialect +} + +// Reconnect update connection to database +func (p *pool) Reconnect() error { + if err := p.Close(); err != nil { + return err + } + + p.l.Lock() + defer p.l.Unlock() + + for _, item := range p.conf.List() { + db, err := sql.Open("sqlite3", item.GetDSN()) + if err != nil { + if er := p.Close(); er != nil { + return errors.Wrap(err, er) + } + return err + } + p.db[item.GetName()] = db + } + return nil +} + +// Close closing connection +func (p *pool) Close() error { + p.l.Lock() + defer p.l.Unlock() + + if len(p.db) > 0 { + for _, db := range p.db { + if err := db.Close(); err != nil { + return err + } + } + } + return nil +} + +// Pool getting connection pool by name +func (p *pool) Pool(name string) (*sql.DB, error) { + p.l.RLock() + defer p.l.RUnlock() + + db, ok := p.db[name] + if !ok { + return nil, schema.ErrPoolNotFound + } + return db, db.Ping() +} diff --git a/sdk/orm/stmt.go b/sdk/orm/stmt.go new file mode 100644 index 0000000..6ea7149 --- /dev/null +++ b/sdk/orm/stmt.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "context" + "database/sql" + + "github.com/osspkg/goppy/sdk/errors" +) + +var ( + errInvalidModelPool = errors.New("invalid decoder pool") +) + +type ( + //Stmt statement model + Stmt interface { + Ping() error + CallContext(name string, ctx context.Context, callFunc func(context.Context, *sql.DB) error) error + TxContext(name string, ctx context.Context, callFunc func(context.Context, *sql.Tx) error) error + + ExecContext(name string, ctx context.Context, call func(q Executor)) error + QueryContext(name string, ctx context.Context, call func(q Querier)) error + TransactionContext(name string, ctx context.Context, call func(v Tx)) error + } + + _stmt struct { + name string + db dbPool + opts *options + } + + dbPool interface { + Dialect() string + Pool(string) (*sql.DB, error) + } +) + +// newStmt init new statement +func newStmt(name string, db dbPool, p *options) Stmt { + return &_stmt{ + name: name, + db: db, + opts: p, + } +} diff --git a/sdk/orm/stmt_exec.go b/sdk/orm/stmt_exec.go new file mode 100644 index 0000000..76d2651 --- /dev/null +++ b/sdk/orm/stmt_exec.go @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "context" + "database/sql" + "sync" + + "github.com/osspkg/goppy/sdk/orm/schema" +) + +var poolExec = sync.Pool{New: func() interface{} { return &exec{} }} + +type exec struct { + Q string + P [][]interface{} + B func(rowsAffected, lastInsertId int64) error +} + +func (v *exec) SQL(query string, args ...interface{}) { + v.Q = query + v.Params(args...) +} + +func (v *exec) Params(args ...interface{}) { + if len(args) > 0 { + v.P = append(v.P, args) + } +} +func (v *exec) Bind(call func(rowsAffected, lastInsertId int64) error) { + v.B = call +} + +func (v *exec) Reset() *exec { + v.Q, v.P, v.B = "", nil, nil + return v +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type ( + //Executor interface for generate execute query + Executor interface { + SQL(query string, args ...interface{}) + Params(args ...interface{}) + Bind(call func(rowsAffected, lastInsertId int64) error) + } +) + +// ExecContext ... +func (s *_stmt) ExecContext(name string, ctx context.Context, call func(q Executor)) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + return callExecContext(ctx, db, call, s.db.Dialect()) + }) +} + +func callExecContext(ctx context.Context, db dbGetter, call func(q Executor), dialect string) error { + q, ok := poolExec.Get().(*exec) + if !ok { + return errInvalidModelPool + } + defer poolExec.Put(q.Reset()) + call(q) + if len(q.P) == 0 { + q.P = append(q.P, []interface{}{}) + } + stmt, err := db.PrepareContext(ctx, q.Q) + if err != nil { + return err + } + defer stmt.Close() //nolint: errcheck + var rowsAffected, lastInsertId int64 + for _, params := range q.P { + result, err0 := stmt.ExecContext(ctx, params...) + if err0 != nil { + return err0 + } + rows, err0 := result.RowsAffected() + if err0 != nil { + return err0 + } + rowsAffected += rows + + if dialect != schema.PgSQLDialect { + rows, err0 = result.LastInsertId() + if err0 != nil { + return err0 + } + lastInsertId = rows + } + } + if err = stmt.Close(); err != nil { + return err + } + if q.B == nil { + return nil + } + return q.B(rowsAffected, lastInsertId) +} diff --git a/sdk/orm/stmt_query.go b/sdk/orm/stmt_query.go new file mode 100644 index 0000000..9600cdd --- /dev/null +++ b/sdk/orm/stmt_query.go @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "context" + "database/sql" + "sync" +) + +var poolQuery = sync.Pool{New: func() interface{} { return &query{} }} + +type query struct { + Q string + P []interface{} + B func(bind Scanner) error +} + +func (v *query) SQL(query string, args ...interface{}) { + v.Q, v.P = query, args +} + +func (v *query) Bind(call func(bind Scanner) error) { + v.B = call +} + +func (v *query) Reset() *query { + v.Q, v.P, v.B = "", nil, nil + return v +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type ( + //Scanner interface for bind data + Scanner interface { + Scan(args ...interface{}) error + } + + //Querier interface for generate query + Querier interface { + SQL(query string, args ...interface{}) + Bind(call func(bind Scanner) error) + } +) + +// QueryContext ... +func (s *_stmt) QueryContext(name string, ctx context.Context, call func(q Querier)) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + return callQueryContext(ctx, db, call) + }) +} + +func callQueryContext(ctx context.Context, db dbGetter, call func(q Querier)) error { + q, ok := poolQuery.Get().(*query) + if !ok { + return errInvalidModelPool + } + defer poolQuery.Put(q.Reset()) + + call(q) + + rows, err := db.QueryContext(ctx, q.Q, q.P...) + if err != nil { + return err + } + defer rows.Close() //nolint: errcheck + if q.B != nil { + for rows.Next() { + if err = q.B(rows); err != nil { + return err + } + } + } + if err = rows.Close(); err != nil { + return err + } + if err = rows.Err(); err != nil { + return err + } + return nil +} diff --git a/sdk/orm/stmt_raw.go b/sdk/orm/stmt_raw.go new file mode 100644 index 0000000..a44282c --- /dev/null +++ b/sdk/orm/stmt_raw.go @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "context" + "database/sql" + + "github.com/osspkg/goppy/sdk/errors" +) + +// Ping database ping +func (s *_stmt) Ping() error { + return s.CallContext("ping", context.Background(), func(ctx context.Context, db *sql.DB) error { + return db.PingContext(ctx) + }) +} + +// CallContext basic query execution +func (s *_stmt) CallContext(name string, ctx context.Context, callFunc func(context.Context, *sql.DB) error) error { + pool, err := s.db.Pool(s.name) + if err != nil { + return err + } + + s.opts.Metrics.ExecutionTime(name, func() { err = callFunc(ctx, pool) }) + + return err +} + +// TxContext the basic execution of a query in a transaction +func (s *_stmt) TxContext(name string, ctx context.Context, callFunc func(context.Context, *sql.Tx) error) error { + return s.CallContext(name, ctx, func(ctx context.Context, db *sql.DB) error { + dbx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + + err = callFunc(ctx, dbx) + if err != nil { + return errors.Wrap( + errors.Wrapf(err, "execute tx"), + errors.Wrapf(dbx.Rollback(), "rollback tx"), + ) + } + + return dbx.Commit() + }) +} diff --git a/sdk/orm/stmt_test.go b/sdk/orm/stmt_test.go new file mode 100644 index 0000000..80dca30 --- /dev/null +++ b/sdk/orm/stmt_test.go @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm_test + +import ( + "context" + "database/sql" + "os" + "testing" + + "github.com/osspkg/goppy/sdk/orm" + "github.com/osspkg/goppy/sdk/orm/plugins" + "github.com/osspkg/goppy/sdk/orm/schema/sqlite" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestUnit_Stmt(t *testing.T) { + file, err := os.CreateTemp("/tmp", "prefix") + require.NoError(t, err) + defer os.Remove(file.Name()) //nolint: errcheck + + conn := sqlite.New(&sqlite.Config{Pool: []sqlite.Item{{Name: "main", File: file.Name()}}}) + require.NoError(t, conn.Reconnect()) + defer conn.Close() //nolint: errcheck + pool := orm.New(conn, + orm.UsePluginLogger(plugins.StdOutLog), + orm.UsePluginMetric(plugins.StdOutMetric), + ).Pool("main") + + err = pool.CallContext("init", context.Background(), func(ctx context.Context, db *sql.DB) error { + sqls := []string{ + `create table users ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + name TEXT + );`, + "insert into `users` (`id`, `name`) values (1, 'aaaa');", + "insert into `users` (`id`, `name`) values (2, 'bbbb');", + } + + for _, item := range sqls { + if _, err = db.ExecContext(ctx, item); err != nil { + return err + } + } + return nil + }) + require.NoError(t, err) + + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users` where `id` = ?", 1) + q.Bind(func(bind orm.Scanner) error { + name := "" + assert.NoError(t, bind.Scan(&name)) + assert.Equal(t, "aaaa", name) + return nil + }) + }) + assert.NoError(t, err) + + var result []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + assert.NoError(t, bind.Scan(&name)) + result = append(result, name) + return nil + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb"}, result) + + err = pool.ExecContext("", context.Background(), func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(3, "cccc") + e.Params(4, "dddd") + + e.Bind(func(rowsAffected, lastInsertId int64) error { + assert.Equal(t, int64(2), rowsAffected) + assert.Equal(t, int64(4), lastInsertId) + return nil + }) + }) + assert.NoError(t, err) + + var result2 []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result2 = append(result2, name) + return err + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd"}, result2) + + var result3 []string + err = pool.TransactionContext("", context.Background(), func(v orm.Tx) { + v.Exec(func(e orm.Executor) { + e.SQL("insert into `users` (`id`, `name`) values (?, ?);") + e.Params(10, "abcd") + e.Params(11, "efgh") + e.Bind(func(rowsAffected, lastInsertId int64) error { + assert.Equal(t, int64(2), rowsAffected) + assert.Equal(t, int64(11), lastInsertId) + return nil + }) + }) + v.Query(func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result3 = append(result3, name) + return err + }) + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd", "abcd", "efgh"}, result3) + + var result4 []string + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + q.SQL("select `name` from `users`") + q.Bind(func(bind orm.Scanner) error { + name := "" + err = bind.Scan(&name) + result4 = append(result4, name) + return err + }) + }) + assert.NoError(t, err) + assert.Equal(t, []string{"aaaa", "bbbb", "cccc", "dddd", "abcd", "efgh"}, result4) +} + +func BenchmarkStmt(b *testing.B) { + file, err := os.CreateTemp("/tmp", "prefix") + require.NoError(b, err) + defer os.Remove(file.Name()) //nolint: errcheck + + conn := sqlite.New(&sqlite.Config{Pool: []sqlite.Item{{Name: "main", File: file.Name()}}}) + require.NoError(b, conn.Reconnect()) + defer conn.Close() //nolint: errcheck + pool := orm.New(conn).Pool("main") + + err = pool.CallContext("init", context.Background(), func(ctx context.Context, db *sql.DB) error { + sqls := []string{ + `create table users ( + id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE, + name TEXT + );`, + } + + for _, item := range sqls { + if _, err = db.ExecContext(ctx, item); err != nil { + return err + } + } + return nil + }) + require.NoError(b, err) + + b.Run("insert", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 1; i < b.N; i++ { + err = pool.ExecContext("", context.Background(), func(e orm.Executor) { + i := i + e.SQL("insert or ignore into `users` (`id`, `name`) values (?, ?);") + e.Params(i, "cccc") + }) + assert.NoError(b, err) + } + }) + + var name string + b.Run("select", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 1; i < b.N; i++ { + err = pool.QueryContext("", context.Background(), func(q orm.Querier) { + i := i + q.SQL("select `name` from `users` where `id` = ?", i) + q.Bind(func(bind orm.Scanner) error { + return bind.Scan(&name) + }) + }) + assert.NoError(b, err) + } + }) +} diff --git a/sdk/orm/stmt_tx.go b/sdk/orm/stmt_tx.go new file mode 100644 index 0000000..d8e3414 --- /dev/null +++ b/sdk/orm/stmt_tx.go @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package orm + +import ( + "context" + "database/sql" + "fmt" + "sync" +) + +var poolTx = sync.Pool{New: func() interface{} { return &tx{} }} + +type ( + Tx interface { + Exec(vv ...func(e Executor)) + Query(vv ...func(q Querier)) + } + + tx struct { + v []interface{} + } + + dbGetter interface { + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + } +) + +func (v *tx) Exec(vv ...func(q Executor)) { + for _, f := range vv { + v.v = append(v.v, f) + } +} + +func (v *tx) Query(vv ...func(q Querier)) { + for _, f := range vv { + v.v = append(v.v, f) + } +} + +func (v *tx) Reset() *tx { + v.v = v.v[:0] + return v +} + +func (s *_stmt) TransactionContext(name string, ctx context.Context, call func(v Tx)) error { + q, ok := poolTx.Get().(*tx) + if !ok { + return errInvalidModelPool + } + defer poolTx.Put(q.Reset()) + + call(q) + + return s.TxContext(name, ctx, func(ctx context.Context, tx *sql.Tx) error { + for i, c := range q.v { + if cc, ok := c.(func(q Executor)); ok { + if err := callExecContext(ctx, tx, cc, s.db.Dialect()); err != nil { + return err + } + continue + } + if cc, ok := c.(func(q Querier)); ok { + if err := callQueryContext(ctx, tx, cc); err != nil { + return err + } + continue + } + return fmt.Errorf("unknown query model #%d", i) + } + return nil + }) +} diff --git a/sdk/random/random.go b/sdk/random/random.go new file mode 100644 index 0000000..a46ae1d --- /dev/null +++ b/sdk/random/random.go @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package random + +import ( + "math/rand" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +var ( + digest = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-+=~*@#$%&?!<>") +) + +func BytesOf(n int, src []byte) []byte { + tmp := make([]byte, len(src)) + copy(tmp, src) + rand.Shuffle(len(tmp), func(i, j int) { + tmp[i], tmp[j] = tmp[j], tmp[i] + }) + b := make([]byte, n) + for i := range b { + b[i] = tmp[rand.Intn(len(tmp))] + } + return b +} + +func StringOf(n int, src string) string { + return string(BytesOf(n, []byte(src))) +} + +func Bytes(n int) []byte { + return BytesOf(n, digest) +} + +func String(n int) string { + return string(Bytes(n)) +} diff --git a/sdk/random/random_test.go b/sdk/random/random_test.go new file mode 100644 index 0000000..4971392 --- /dev/null +++ b/sdk/random/random_test.go @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package random_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/osspkg/goppy/sdk/random" +) + +func TestUnit_Bytes(t *testing.T) { + max := 10 + r1 := random.Bytes(max) + r2 := random.Bytes(max) + + fmt.Println(string(r1), string(r2)) + + if len(r1) != max || len(r2) != max { + t.Errorf("invalid len, is not %d", max) + } + if bytes.Equal(r1, r2) { + t.Errorf("result is not random") + } +} + +func TestUnit_BytesOf(t *testing.T) { + max := 10 + src := []byte("1234567890") + r1 := random.BytesOf(max, src) + r2 := random.BytesOf(max, src) + + fmt.Println(string(r1), string(r2)) + + if len(r1) != max || len(r2) != max { + t.Errorf("invalid len, is not %d", max) + } + if bytes.Equal(r1, r2) { + t.Errorf("result is not random") + } +} + +func Benchmark_Bytes64(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + random.Bytes(64) + } +} + +func Benchmark_Bytes256(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + random.Bytes(256) + } +} diff --git a/sdk/routine/routine.go b/sdk/routine/routine.go new file mode 100644 index 0000000..a933500 --- /dev/null +++ b/sdk/routine/routine.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package routine + +import ( + "context" + "time" + + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iosync" +) + +func Interval(ctx context.Context, interval time.Duration, call func(context.Context)) { + call(ctx) + + go func() { + tick := time.NewTicker(interval) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + call(ctx) + } + } + }() +} + +func Retry(count int, ttl time.Duration, call func() error) error { + var err error + for i := 0; i < count; i++ { + if e := call(); e != nil { + err = errors.Wrap(err, errors.Wrapf(e, "[#%d]", i)) + time.Sleep(ttl) + continue + } + return nil + } + return errors.Wrapf(err, "retry error") +} + +func Parallel(calls ...func()) { + wg := iosync.NewGroup() + for _, call := range calls { + call := call + wg.Background(func() { + call() + }) + } + wg.Wait() +} diff --git a/sdk/routine/routine_test.go b/sdk/routine/routine_test.go new file mode 100644 index 0000000..14bd500 --- /dev/null +++ b/sdk/routine/routine_test.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package routine_test + +import ( + "fmt" + "testing" + + "github.com/osspkg/goppy/sdk/routine" +) + +func TestUnit_Parallel(t *testing.T) { + routine.Parallel( + func() { + fmt.Println("a") + }, func() { + fmt.Println("b") + }, func() { + fmt.Println("c") + }, + ) +} diff --git a/sdk/shell/shell.go b/sdk/shell/shell.go new file mode 100644 index 0000000..b45a52c --- /dev/null +++ b/sdk/shell/shell.go @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package shell + +import ( + "context" + "io" + "os" + "os/exec" + "sync" + + "github.com/osspkg/goppy/sdk/errors" +) + +type ( + sh struct { + env []string + dir string + shell string + mux sync.RWMutex + w io.Writer + ch chan []byte + } + + Shell interface { + Close() + SetEnv(key, value string) + SetDir(dir string) + SetShell(shell string) + SetWriter(w io.Writer) + CallPackageContext(ctx context.Context, commands ...string) error + CallContext(ctx context.Context, command string) error + Call(ctx context.Context, command string) ([]byte, error) + } +) + +func New() Shell { + v := &sh{ + env: make([]string, 0), + dir: os.TempDir(), + shell: "/bin/sh", + w: &NullWriter{}, + ch: make(chan []byte, 128), + } + go v.Pipe() + return v +} + +func (v *sh) SetEnv(key, value string) { + v.mux.Lock() + defer v.mux.Unlock() + + v.env = append(v.env, key+"="+value) +} + +func (v *sh) SetDir(dir string) { + v.mux.Lock() + defer v.mux.Unlock() + + v.dir = dir +} + +func (v *sh) SetShell(shell string) { + v.mux.Lock() + defer v.mux.Unlock() + + v.shell = shell +} + +func (v *sh) SetWriter(w io.Writer) { + v.mux.Lock() + defer v.mux.Unlock() + + v.w = w +} + +func (v *sh) Close() { + v.SetWriter(&NullWriter{}) + close(v.ch) +} + +func (v *sh) Pipe() { + for { + b, ok := <-v.ch + if !ok { + return + } + bb := make([]byte, len(b)) + copy(bb, b) + v.mux.RLock() + v.w.Write(bb) //nolint:errcheck + v.mux.RUnlock() + } +} + +func (v *sh) Write(b []byte) (n int, err error) { + l := len(b) + select { + case v.ch <- b: + default: + } + return l, nil +} + +func (v *sh) CallPackageContext(ctx context.Context, commands ...string) error { + for i, command := range commands { + if err := v.CallContext(ctx, command); err != nil { + return errors.Wrapf(err, "call command #%d [%s]", i, command) + } + } + return nil +} + +func (v *sh) CallContext(ctx context.Context, c string) error { + v.mux.RLock() + cmd := exec.CommandContext(ctx, v.shell, "-xec", c, " <&-") + cmd.Env = append(os.Environ(), v.env...) + cmd.Dir = v.dir + cmd.Stdout = v + cmd.Stderr = v + v.mux.RUnlock() + + return cmd.Run() +} + +func (v *sh) Call(ctx context.Context, c string) ([]byte, error) { + v.mux.RLock() + cmd := exec.CommandContext(ctx, v.shell, "-xec", c, " <&-") + cmd.Env = append(os.Environ(), v.env...) + cmd.Dir = v.dir + v.mux.RUnlock() + + return cmd.CombinedOutput() +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +type NullWriter struct { +} + +func (v *NullWriter) Write(b []byte) (int, error) { + return len(b), nil +} diff --git a/sdk/shell/shell_test.go b/sdk/shell/shell_test.go new file mode 100644 index 0000000..0ff1064 --- /dev/null +++ b/sdk/shell/shell_test.go @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package shell_test + +import ( + "context" + "testing" + + "github.com/osspkg/goppy/sdk/shell" +) + +func TestUnit_ShellCall(t *testing.T) { + sh := shell.New() + sh.SetDir("/tmp") + sh.SetEnv("LANG", "en_US.UTF-8") + + out, err := sh.Call(context.TODO(), "ls -la /tmp") + if err != nil { + t.Fatalf(err.Error()) + } + t.Log(string(out)) +} diff --git a/sdk/syscall/system.go b/sdk/syscall/system.go new file mode 100644 index 0000000..55f8de2 --- /dev/null +++ b/sdk/syscall/system.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package syscall + +import ( + "os" + "os/signal" + "strconv" + scall "syscall" +) + +// OnStop calling a function if you send a system event stop +func OnStop(callFunc func()) { + quit := make(chan os.Signal, 4) + signal.Notify(quit, os.Interrupt, scall.SIGINT, scall.SIGTERM, scall.SIGKILL) //nolint:staticcheck + <-quit + + callFunc() +} + +// OnUp calling a function if you send a system event SIGHUP +func OnUp(callFunc func()) { + quit := make(chan os.Signal, 1) + signal.Notify(quit, scall.SIGHUP) + <-quit + + callFunc() +} + +// OnCustom calling a function if you send a system custom event +func OnCustom(callFunc func(), sig ...os.Signal) { + quit := make(chan os.Signal, 1) + signal.Notify(quit, sig...) + <-quit + + callFunc() +} + +// Pid write pid file +func Pid(filename string) error { + pid := strconv.Itoa(scall.Getpid()) + return os.WriteFile(filename, []byte(pid), 0755) +} diff --git a/sdk/webutil/client.go b/sdk/webutil/client.go new file mode 100644 index 0000000..6f9b97e --- /dev/null +++ b/sdk/webutil/client.go @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "time" + + "github.com/osspkg/goppy/sdk/ioutil" + "github.com/osspkg/goppy/sdk/webutil/signature" +) + +type ( + ClientHttp struct { + cli *http.Client + + headers http.Header + signStore signature.Storage + + enc func(in interface{}) (body []byte, contentType string, err error) + dec func(code int, contentType string, body []byte, out interface{}) error + } +) + +func NewClientHttp(opt ...ClientHttpOption) *ClientHttp { + cli := &ClientHttp{ + cli: http.DefaultClient, + headers: make(http.Header), + } + ClientHttpOptionSetup("env", 5*time.Second, 100)(cli) + ClientHttpOptionCodecDefault()(cli) + for _, option := range opt { + option(cli) + } + return cli +} + +func (v *ClientHttp) Call(ctx context.Context, method, uri string, in interface{}, out interface{}) error { + var ( + ioBody io.Reader + b []byte + contentType string + err error + u *url.URL + ) + + if u, err = url.Parse(uri); err != nil { + return err + } + + if in != nil { + if b, contentType, err = v.enc(in); err != nil { + return err + } + ioBody = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, uri, ioBody) + if err != nil { + return err + } + + req.Header.Set("Connection", "keep-alive") + for k := range v.headers { + req.Header.Set(k, v.headers.Get(k)) + } + if len(contentType) > 0 { + req.Header.Set("Content-Type", contentType) + } + + if v.signStore != nil { + if sign := v.signStore.Get(u.Host); sign != nil { + signature.Encode(req.Header, sign, b) + } + } + + resp, err := v.cli.Do(req) //nolint: bodyclose + if err != nil { + return err + } + + b, err = ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + + return v.dec(resp.StatusCode, resp.Header.Get("Content-Type"), b, out) +} + +/**********************************************************************************************************************/ + +type ClientHttpOption func(c *ClientHttp) + +func ClientHttpOptionCodec( + enc func(in interface{}) (body []byte, contentType string, err error), + dec func(code int, contentType string, body []byte, out interface{}) error, +) ClientHttpOption { + return func(c *ClientHttp) { + c.enc = enc + c.dec = dec + } +} + +func ClientHttpOptionCodecDefault() ClientHttpOption { + return ClientHttpOptionCodec( + func(in interface{}) (body []byte, contentType string, err error) { + switch v := in.(type) { + case []byte: + return v, "", nil + case json.Marshaler: + body, err = v.MarshalJSON() + return body, "application/json; charset=utf-8", err + default: + return nil, "", fmt.Errorf("unknown request format %T", in) + } + }, + func(code int, _ string, body []byte, out interface{}) error { + switch code { + case 200: + switch v := out.(type) { + case *[]byte: + *v = append(*v, body...) + return nil + case json.Unmarshaler: + return v.UnmarshalJSON(body) + default: + return fmt.Errorf("unknown response format %T", out) + } + + default: + return fmt.Errorf("%d %s", code, http.StatusText(code)) + } + }, + ) +} + +func ClientHttpOptionSetup(proxy string, ttl time.Duration, countConn int) ClientHttpOption { + return func(c *ClientHttp) { + c.cli.Timeout = ttl + dial := &net.Dialer{ + Timeout: 15 * ttl, + KeepAlive: 15 * ttl, + } + c.cli.Transport = &http.Transport{ + Proxy: proxySetup(proxy), + DialContext: dial.DialContext, + MaxIdleConns: countConn, + MaxIdleConnsPerHost: countConn, + } + } +} + +func ClientHttpOptionHeaders(keyVal ...string) ClientHttpOption { + if len(keyVal)%2 != 0 { + keyVal = append(keyVal, "") + } + return func(c *ClientHttp) { + for i := 0; i < len(keyVal); i += 2 { + c.headers.Set(keyVal[i], keyVal[i+1]) + } + } +} + +func ClientHttpOptionAuth(s signature.Storage) ClientHttpOption { + return func(c *ClientHttp) { + c.signStore = s + } +} + +func proxySetup(proxy string) func(r *http.Request) (*url.URL, error) { + if len(proxy) == 0 || proxy == "env" { + return http.ProxyFromEnvironment + } + u, err := url.Parse(proxy) + if err != nil { + return func(r *http.Request) (*url.URL, error) { + return nil, err + } + } + return http.ProxyURL(u) +} diff --git a/sdk/webutil/client_test.go b/sdk/webutil/client_test.go new file mode 100644 index 0000000..d2d176e --- /dev/null +++ b/sdk/webutil/client_test.go @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/osspkg/goppy/sdk/webutil" + "github.com/stretchr/testify/require" +) + +type ( + TestModel struct { + Val struct { + Page struct { + Name string `json:"name"` + } `json:"page"` + } + } +) + +func (v *TestModel) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &v.Val) +} + +func TestUnit_NewClientHttp_JSON(t *testing.T) { + model := TestModel{} + cli := webutil.NewClientHttp() + err := cli.Call(context.TODO(), http.MethodGet, "https://www.githubstatus.com/api/v2/status.json", nil, &model) + require.NoError(t, err) + require.Equal(t, "GitHub", model.Val.Page.Name) +} + +func TestUnit_NewClientHttp_Bytes(t *testing.T) { + var model []byte + cli := webutil.NewClientHttp() + err := cli.Call(context.TODO(), http.MethodGet, "https://www.githubstatus.com/api/v2/status.json", nil, &model) + require.NoError(t, err) + require.Contains(t, string(model), ",\"name\":\"GitHub\",") +} diff --git a/sdk/webutil/codec.go b/sdk/webutil/codec.go new file mode 100644 index 0000000..e87b44c --- /dev/null +++ b/sdk/webutil/codec.go @@ -0,0 +1,72 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + + "github.com/osspkg/goppy/sdk/ioutil" +) + +func JSONEncode(w http.ResponseWriter, v interface{}) { + b, err := json.Marshal(v) + if err != nil { + ErrorEncode(w, err) + return + } + w.Header().Add("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(b) //nolint: errcheck +} + +func JSONDecode(r *http.Request, v interface{}) error { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + return json.Unmarshal(b, v) +} + +func XMLEncode(w http.ResponseWriter, v interface{}) { + b, err := xml.Marshal(v) + if err != nil { + ErrorEncode(w, err) + return + } + w.Header().Add("Content-Type", "application/xml; charset=utf-8") + w.WriteHeader(http.StatusOK) + w.Write(b) //nolint: errcheck +} + +func XMLDecode(r *http.Request, v interface{}) error { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return err + } + return xml.Unmarshal(b, v) +} + +func ErrorEncode(w http.ResponseWriter, v error) { + w.Header().Add("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(v.Error())) //nolint: errcheck +} + +func StreamEncode(w http.ResponseWriter, v []byte, filename string) { + w.Header().Add("Content-Type", "application/octet-stream") + w.Header().Add("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, filename)) + w.WriteHeader(http.StatusOK) + w.Write(v) //nolint: errcheck +} + +func RawEncode(w http.ResponseWriter, v []byte) { + w.Header().Add("Content-Type", http.DetectContentType(v)) + w.WriteHeader(http.StatusOK) + w.Write(v) //nolint: errcheck +} diff --git a/sdk/webutil/common.go b/sdk/webutil/common.go new file mode 100644 index 0000000..eb9ae74 --- /dev/null +++ b/sdk/webutil/common.go @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "strings" + "time" + + "github.com/osspkg/goppy/sdk/errors" +) + +const ( + defaultTimeout = 10 * time.Second + defaultShutdownTimeout = 1 * time.Second + defaultNetwork = "tcp" +) + +var ( + errServAlreadyRunning = errors.New("server already running") + errServAlreadyStopped = errors.New("server already stopped") + errFailContextKey = errors.New("context key is not found") +) + +var ( + networkType = map[string]struct{}{ + "tcp": {}, + "tcp4": {}, + "tcp6": {}, + "unixpacket": {}, + "unix": {}, + } +) + +/**********************************************************************************************************************/ + +const urlSplitSeparate = "/" + +func urlSplit(uri string) []string { + vv := strings.Split(strings.ToLower(uri), urlSplitSeparate) + for i := 0; i < len(vv); i++ { + if len(vv[i]) == 0 { + copy(vv[i:], vv[i+1:]) + vv = vv[:len(vv)-1] + i-- + } + } + return vv +} diff --git a/sdk/webutil/common_test.go b/sdk/webutil/common_test.go new file mode 100644 index 0000000..0d3110d --- /dev/null +++ b/sdk/webutil/common_test.go @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "reflect" + "testing" +) + +func TestUnit_urlSplit(t *testing.T) { + type args struct { + uri string + } + tests := []struct { + name string + args args + want []string + }{ + {name: "Case1", args: args{uri: ""}, want: []string{}}, + {name: "Case2", args: args{uri: "/a/b/"}, want: []string{"a", "b"}}, + {name: "Case3", args: args{uri: "/a/////b/"}, want: []string{"a", "b"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := urlSplit(tt.args.uri); !reflect.DeepEqual(got, tt.want) { + t.Errorf("split() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sdk/webutil/route.go b/sdk/webutil/route.go new file mode 100644 index 0000000..a39a971 --- /dev/null +++ b/sdk/webutil/route.go @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "context" + "net/http" + "sync" +) + +var _ http.Handler = (*Router)(nil) + +// Router model +type Router struct { + handler *ctrlHandler + lock sync.RWMutex +} + +// NewRouter init new router +func NewRouter() *Router { + return &Router{ + handler: newCtrlHandler(), + } +} + +// Route add new route +func (v *Router) Route(path string, ctrl func(http.ResponseWriter, *http.Request), methods ...string) { + v.lock.Lock() + v.handler.Route(path, ctrl, methods) + v.lock.Unlock() +} + +// Global add global middlewares +func (v *Router) Global( + middlewares ...func(func(http.ResponseWriter, *http.Request), + ) func(http.ResponseWriter, *http.Request)) { + v.lock.Lock() + v.handler.Middlewares("", middlewares...) + v.lock.Unlock() +} + +// Middlewares add middlewares to route +func (v *Router) Middlewares( + path string, middlewares ...func(func(http.ResponseWriter, *http.Request), + ) func(http.ResponseWriter, *http.Request)) { + v.lock.Lock() + v.handler.Middlewares(path, middlewares...) + v.lock.Unlock() +} + +// NoFoundHandler ctrlHandler call if route not found +func (v *Router) NoFoundHandler(call func(http.ResponseWriter, *http.Request)) { + v.lock.Lock() + v.handler.NoFoundHandler(call) + v.lock.Unlock() +} + +// ServeHTTP http interface +func (v *Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { + v.lock.RLock() + defer v.lock.RUnlock() + + code, next, params, midd := v.handler.Match(r.URL.Path, r.Method) + if code != http.StatusOK { + next = codeHandler(code) + } + + ctx := r.Context() + for key, val := range params { + ctx = context.WithValue(ctx, uriParamKey(key), val) + } + + for i := len(midd) - 1; i >= 0; i-- { + next = midd[i](next) + } + next(w, r.WithContext(ctx)) +} + +func codeHandler(code int) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + } +} diff --git a/sdk/webutil/route_handler.go b/sdk/webutil/route_handler.go new file mode 100644 index 0000000..bfb12c5 --- /dev/null +++ b/sdk/webutil/route_handler.go @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "net/http" + "strings" +) + +const anyPath = "#" + +type ctrlHandler struct { + list map[string]*ctrlHandler + methods map[string]func(http.ResponseWriter, *http.Request) + matcher *paramMatch + middlewares []func(func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) + notFound func(http.ResponseWriter, *http.Request) +} + +func newCtrlHandler() *ctrlHandler { + return &ctrlHandler{ + list: make(map[string]*ctrlHandler), + methods: make(map[string]func(http.ResponseWriter, *http.Request)), + matcher: newParamMatch(), + middlewares: make([]func(func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request), 0), + } +} + +func (v *ctrlHandler) append(path string) *ctrlHandler { + if uh, ok := v.list[path]; ok { + return uh + } + uh := newCtrlHandler() + v.list[path] = uh + return uh +} + +func (v *ctrlHandler) next(path string, vars uriParamData) (*ctrlHandler, bool) { + if uh, ok := v.list[path]; ok { + return uh, false + } + if uri, ok := v.matcher.Match(path, vars); ok { + if uh, ok1 := v.list[uri]; ok1 { + return uh, false + } + } + if uh, ok := v.list[anyPath]; ok { + return uh, true + } + return nil, false +} + +// Route add new route +func (v *ctrlHandler) Route(path string, ctrl func(http.ResponseWriter, *http.Request), methods []string) { + uh := v + uris := urlSplit(path) + for _, uri := range uris { + if hasParamMatch(uri) { + if err := uh.matcher.Add(uri); err != nil { + panic(err) + } + } + uh = uh.append(uri) + } + for _, m := range methods { + uh.methods[strings.ToUpper(m)] = ctrl + } +} + +// Middlewares add middleware to route +func (v *ctrlHandler) Middlewares( + path string, middlewares ...func(func(http.ResponseWriter, *http.Request), + ) func(http.ResponseWriter, *http.Request)) { + uh := v + uris := urlSplit(path) + for _, uri := range uris { + uh = uh.append(uri) + } + uh.middlewares = append(uh.middlewares, middlewares...) +} + +func (v *ctrlHandler) NoFoundHandler(call func(http.ResponseWriter, *http.Request)) { + v.notFound = call +} + +// Match find route in tree +func (v *ctrlHandler) Match(path string, method string) ( + int, func(http.ResponseWriter, *http.Request), uriParamData, []func(func(http.ResponseWriter, *http.Request), + ) func(http.ResponseWriter, *http.Request)) { + uh := v + uris := urlSplit(path) + midd := append(make([]func(func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request), + 0, len(uh.middlewares)), uh.middlewares...) + vr := uriParamData{} + var isBreak bool + for _, uri := range uris { + if uh, isBreak = uh.next(uri, vr); uh != nil { + midd = append(midd, uh.middlewares...) + if isBreak { + break + } + continue + } + if v.notFound != nil { + return http.StatusOK, v.notFound, nil, midd + } + return http.StatusNotFound, nil, nil, v.middlewares + } + if ctrl, ok := uh.methods[method]; ok { + return http.StatusOK, ctrl, vr, midd + } + if v.notFound != nil { + return http.StatusOK, v.notFound, nil, midd + } + if len(uh.methods) == 0 { + return http.StatusNotFound, nil, nil, v.middlewares + } + return http.StatusMethodNotAllowed, nil, nil, v.middlewares +} diff --git a/sdk/webutil/route_handler_test.go b/sdk/webutil/route_handler_test.go new file mode 100644 index 0000000..f2a1118 --- /dev/null +++ b/sdk/webutil/route_handler_test.go @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnit_NewHandler(t *testing.T) { + h := newCtrlHandler() + h.Route("/aaa/{id}", func(_ http.ResponseWriter, _ *http.Request) {}, []string{http.MethodPost}) + h.Route("", func(_ http.ResponseWriter, _ *http.Request) {}, []string{http.MethodPost}) + + code, ctrl, vr, midd := h.Match("/aaa/bbb", http.MethodPost) + require.Equal(t, 200, code) + require.NotNil(t, ctrl) + require.Equal(t, 0, len(midd)) + require.Equal(t, uriParamData{"id": "bbb"}, vr) + + h.Middlewares("/aaa", RecoveryMiddleware(nil)) + h.Middlewares("", RecoveryMiddleware(nil)) + + code, ctrl, vr, midd = h.Match("/aaa/ccc", http.MethodGet) + require.Equal(t, http.StatusMethodNotAllowed, code) + require.Nil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData(nil), vr) + + code, ctrl, vr, midd = h.Match("/aaa/bbb", http.MethodPost) + require.Equal(t, http.StatusOK, code) + require.NotNil(t, ctrl) + require.Equal(t, 2, len(midd)) + require.Equal(t, uriParamData{"id": "bbb"}, vr) + + code, ctrl, vr, midd = h.Match("", http.MethodPost) + require.Equal(t, http.StatusOK, code) + require.NotNil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData{}, vr) + + h.Middlewares("/www/www/www", RecoveryMiddleware(nil)) + + code, ctrl, vr, midd = h.Match("/www/www/www", http.MethodPost) + require.Equal(t, http.StatusNotFound, code) + require.Nil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData(nil), vr) + + code, ctrl, vr, midd = h.Match("/test", http.MethodGet) + require.Equal(t, http.StatusNotFound, code) + require.Nil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData(nil), vr) + + h.NoFoundHandler(func(_ http.ResponseWriter, _ *http.Request) {}) + + code, ctrl, vr, midd = h.Match("/test", http.MethodGet) + require.Equal(t, http.StatusOK, code) + require.NotNil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData(nil), vr) +} + +func TestUnit_NewHandler2(t *testing.T) { + h := newCtrlHandler() + h.Route("/api/v{id}/data/#", func(_ http.ResponseWriter, _ *http.Request) {}, []string{http.MethodGet}) + + h.Middlewares("/api/v{id}", RecoveryMiddleware(nil)) + + code, ctrl, vr, midd := h.Match("/api/v1/data/user/aaaa", http.MethodGet) + require.Equal(t, http.StatusOK, code) + require.NotNil(t, ctrl) + require.Equal(t, 1, len(midd)) + require.Equal(t, uriParamData{"id": "1"}, vr) + +} diff --git a/sdk/webutil/route_middleware.go b/sdk/webutil/route_middleware.go new file mode 100644 index 0000000..b3c12c0 --- /dev/null +++ b/sdk/webutil/route_middleware.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "net/http" + + "github.com/osspkg/goppy/sdk/log" +) + +// RecoveryMiddleware recovery go panic and write to log +func RecoveryMiddleware(l log.Logger) func( + func(http.ResponseWriter, *http.Request), +) func(http.ResponseWriter, *http.Request) { + return func(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + if l != nil { + l.WithFields(log.Fields{"err": err}).Errorf("Recovered") + } + w.WriteHeader(http.StatusInternalServerError) + } + }() + f(w, r) + } + } +} diff --git a/sdk/webutil/route_param.go b/sdk/webutil/route_param.go new file mode 100644 index 0000000..119889c --- /dev/null +++ b/sdk/webutil/route_param.go @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "fmt" + "net/http" + "regexp" + "strconv" + "strings" +) + +var uriParamRex = regexp.MustCompile(`\{([a-z0-9]+)\:?([^{}]*)\}`) + +type paramMatch struct { + incr int + keys map[string]string + links map[string]string + pattern string + rex *regexp.Regexp +} + +func newParamMatch() *paramMatch { + return ¶mMatch{ + incr: 1, + pattern: "", + keys: make(map[string]string), + links: make(map[string]string), + } +} + +func (v *paramMatch) Add(vv string) error { + result := vv + + patterns := uriParamRex.FindAllString(vv, -1) + for _, pattern := range patterns { + res := uriParamRex.FindAllStringSubmatch(pattern, 1)[0] + + key := fmt.Sprintf("k%d", v.incr) + rex := ".+" + if len(res) == 3 && len(res[2]) > 0 { + rex = res[2] + } + result = strings.Replace(result, res[0], fmt.Sprintf("(?P<%s>%s)", key, rex), 1) + + v.links[key] = vv + v.keys[key] = res[1] + v.incr++ + } + + result = "^" + result + "$" + + if _, err := regexp.Compile(result); err != nil { + return fmt.Errorf("regex compilation error for `%s`: %w", vv, err) + } + + if len(v.pattern) != 0 { + v.pattern += "|" + } + v.pattern += result + v.rex = regexp.MustCompile(v.pattern) + return nil +} + +func (v *paramMatch) Match(vv string, vr uriParamData) (string, bool) { + if v.rex == nil { + return "", false + } + + matches := v.rex.FindStringSubmatch(vv) + if len(matches) == 0 { + return "", false + } + + link := "" + for indx, name := range v.rex.SubexpNames() { + val := matches[indx] + if len(val) == 0 { + continue + } + if l, ok := v.links[name]; ok { + link = l + } + if key, ok := v.keys[name]; ok { + vr[key] = val + } + } + + return link, true +} + +func hasParamMatch(vv string) bool { + return uriParamRex.MatchString(vv) +} + +/**********************************************************************************************************************/ + +type ( + uriParamKey string + uriParamData map[string]string +) + +func ParamString(r *http.Request, key string) (string, error) { + if v := r.Context().Value(uriParamKey(key)); v != nil { + return v.(string), nil + } + return "", errFailContextKey +} + +func ParamInt(r *http.Request, key string) (int64, error) { + v, err := ParamString(r, key) + if err != nil { + return 0, err + } + return strconv.ParseInt(v, 10, 64) +} + +func ParamFloat(r *http.Request, key string) (float64, error) { + v, err := ParamString(r, key) + if err != nil { + return 0, err + } + return strconv.ParseFloat(v, 64) +} diff --git a/sdk/webutil/route_param_test.go b/sdk/webutil/route_param_test.go new file mode 100644 index 0000000..3fd73aa --- /dev/null +++ b/sdk/webutil/route_param_test.go @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHasMatcher(t *testing.T) { + tests := []struct { + name string + args string + want bool + }{ + {name: "Case1", args: `test-{id:\d+}`, want: true}, + {name: "Case2", args: `test-id:\d+`, want: false}, + {name: "Case3", args: `test-{id}`, want: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := hasParamMatch(tt.args); got != tt.want { + t.Errorf("HasMatcher() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnit_NewMatcher(t *testing.T) { + mt := newParamMatch() + + tests1 := []struct { + name string + args string + wantErr bool + }{ + {name: "c1", args: `page-{id}-{title:[\]}`, wantErr: true}, + {name: "c2", args: `page-{id:\d+}-{title2:[0-9]+}`, wantErr: false}, + {name: "c3", args: `page-{id:\d+}-{title1:[a-zA-Z]+}`, wantErr: false}, + {name: "c4", args: `page-{id:\d+}-{title3:.+}`, wantErr: false}, + {name: "c5", args: `page-{id:\d+}-{title5:.*}`, wantErr: false}, + } + for _, tt := range tests1 { + t.Run(tt.name, func(t *testing.T) { + err := mt.Add(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Matcher.Add() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } + + type args struct { + vv string + vr uriParamData + } + tests2 := []struct { + name string + args args + want string + want1 bool + }{ + { + name: "c6", + args: args{ + vv: "hello", + vr: uriParamData{}, + }, + want: "", + want1: false, + }, + { + name: "c7", + args: args{ + vv: "page--", + vr: uriParamData{}, + }, + want: "", + want1: false, + }, + { + name: "c8", + args: args{ + vv: "page-123-Hello", + vr: uriParamData{"id": "123", "title1": "Hello"}, + }, + want: `page-{id:\d+}-{title1:[a-zA-Z]+}`, + want1: true, + }, + { + name: "c9", + args: args{ + vv: "page-123-0000", + vr: uriParamData{"id": "123", "title2": "0000"}, + }, + want: `page-{id:\d+}-{title2:[0-9]+}`, + want1: true, + }, + { + name: "c10", + args: args{ + vv: "page-123-bb-88", + vr: uriParamData{"id": "123", "title3": "bb-88"}, + }, + want: `page-{id:\d+}-{title3:.+}`, + want1: true, + }, + { + name: "c11", + args: args{ + vv: "page-123-", + vr: uriParamData{"id": "123"}, + }, + want: `page-{id:\d+}-{title5:.*}`, + want1: true, + }, + } + for _, tt := range tests2 { + t.Run(tt.name, func(t *testing.T) { + params := uriParamData{} + got, got1 := mt.Match(tt.args.vv, params) + if got != tt.want { + t.Errorf("Matcher.Match() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Matcher.Match() got1 = %v, want %v", got1, tt.want1) + } + require.Equal(t, tt.args.vr, params, "Matcher.Match() params = %v, want %v", params, tt.args.vr) + }) + } +} + +func TestUnit_NewMatcher1(t *testing.T) { + mt := newParamMatch() + require.NoError(t, mt.Add(`{id}`)) + params := uriParamData{} + path, ok := mt.Match("bbb", params) + require.True(t, ok) + require.Equal(t, `{id}`, path) + require.Equal(t, uriParamData{"id": "bbb"}, params) +} diff --git a/sdk/webutil/route_test.go b/sdk/webutil/route_test.go new file mode 100644 index 0000000..9d57c0d --- /dev/null +++ b/sdk/webutil/route_test.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/osspkg/goppy/sdk/webutil" + "github.com/stretchr/testify/require" +) + +func TestUnit_Route1(t *testing.T) { + result := new(string) + r := webutil.NewRouter() + r.Global(func(c func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + *result += "1" + c(w, r) + } + }) + r.Global(func(c func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + *result += "2" + c(w, r) + } + }) + r.Global(func(c func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + *result += "3" + c(w, r) + } + }) + r.Route("/", func(w http.ResponseWriter, r *http.Request) { + *result += "Ctrl" + }, http.MethodGet) + r.Middlewares("/test", func(c func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + *result += "4" + c(w, r) + } + }) + r.Middlewares("/", func(c func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + *result += "5" + c(w, r) + } + }) + + w := httptest.NewRecorder() + r.ServeHTTP(w, httptest.NewRequest("GET", "/", nil)) + require.Equal(t, "1235Ctrl", *result) +} + +type statusInterface interface { + Result() *http.Response +} + +func getStatusAndClose(s statusInterface) int { + resp := s.Result() + code := resp.StatusCode + err := resp.Body.Close() + if err != nil { + fmt.Println(err.Error()) + return -1 + } + return code +} + +func TestUnit_Route2(t *testing.T) { + r := webutil.NewRouter() + r.Route("/{id}", func(w http.ResponseWriter, r *http.Request) {}, http.MethodGet) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/aaa/bbb/ccc/eee/ggg/fff/kkk", nil) + r.ServeHTTP(w, req) + require.Equal(t, 404, getStatusAndClose(w)) + + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/aaa/", nil) + r.ServeHTTP(w, req) + require.Equal(t, 200, getStatusAndClose(w)) + + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/aaa", nil) + r.ServeHTTP(w, req) + require.Equal(t, 200, getStatusAndClose(w)) + + w = httptest.NewRecorder() + req = httptest.NewRequest("GET", "/aaa?a=1", nil) + r.ServeHTTP(w, req) + require.Equal(t, 200, getStatusAndClose(w)) +} + +func mockNilHandler(_ http.ResponseWriter, _ *http.Request) {} + +func BenchmarkRouter0(b *testing.B) { + serv := webutil.NewRouter() + serv.Route(`/aaa/bbb/ccc/eee/ggg/fff/kkk`, mockNilHandler, http.MethodGet) + serv.Route(`/aaa/bbb/000/eee/ggg/fff/kkk`, mockNilHandler, http.MethodGet) + + req := []*http.Request{ + httptest.NewRequest("GET", "/aaa/bbb/ccc/eee/ggg/fff/kkk", nil), + httptest.NewRequest("GET", "/aaa/bbb/000/eee/ggg/fff/kkk", nil), + } + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + w := httptest.NewRecorder() + b.Run("", func(b *testing.B) { + for i := 0; i < b.N; i++ { + serv.ServeHTTP(w, req[i%2]) + code := getStatusAndClose(w) + if code != http.StatusOK { + b.Fatalf("invalid code: %d", code) + } + w.Flush() + } + }) + } + }) +} + +func BenchmarkRouter1(b *testing.B) { + serv := webutil.NewRouter() + serv.Route(`/{id0}/{id1}/{id2:\d+}/{id3}/{id4}/{id5}/{id6}`, mockNilHandler, http.MethodGet) + serv.Route(`/{id0}/{id1}/{id2:\w+}/{id3}/{id4}/{id5}/{id6}`, mockNilHandler, http.MethodGet) + + req := []*http.Request{ + httptest.NewRequest("GET", "/aaa/bbb/ccc/eee/ggg/fff/kkk", nil), + httptest.NewRequest("GET", "/aaa/bbb/000/eee/ggg/fff/kkk", nil), + } + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + w := httptest.NewRecorder() + b.Run("", func(b *testing.B) { + for i := 0; i < b.N; i++ { + serv.ServeHTTP(w, req[i%2]) + code := getStatusAndClose(w) + if code != http.StatusOK { + b.Fatalf("invalid code: %d", code) + } + w.Flush() + } + }) + } + }) +} diff --git a/sdk/webutil/server_debug.go b/sdk/webutil/server_debug.go new file mode 100755 index 0000000..a72922f --- /dev/null +++ b/sdk/webutil/server_debug.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "net/http" + "net/http/pprof" + + "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/log" +) + +// ServerDebug service model +type ServerDebug struct { + server *ServerHttp + route *Router +} + +// NewServerDebug init debug service +func NewServerDebug(c ConfigHttp, l log.Logger) *ServerDebug { + route := NewRouter() + return &ServerDebug{ + server: NewServerHttp(c, route, l), + route: route, + } +} + +// Up start service +func (o *ServerDebug) Up(ctx app.Context) error { + o.route.Route("/debug/pprof", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/goroutine", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/allocs", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/block", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/heap", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/mutex", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/threadcreate", pprof.Index, http.MethodGet) + o.route.Route("/debug/pprof/cmdline", pprof.Cmdline, http.MethodGet) + o.route.Route("/debug/pprof/profile", pprof.Profile, http.MethodGet) + o.route.Route("/debug/pprof/symbol", pprof.Symbol, http.MethodGet) + o.route.Route("/debug/pprof/trace", pprof.Trace, http.MethodGet) + return o.server.Up(ctx) +} + +// Down stop service +func (o *ServerDebug) Down() error { + return o.server.Down() +} diff --git a/sdk/webutil/server_http.go b/sdk/webutil/server_http.go new file mode 100644 index 0000000..5621f56 --- /dev/null +++ b/sdk/webutil/server_http.go @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package webutil + +import ( + "context" + "net" + "net/http" + "time" + + application "github.com/osspkg/goppy/sdk/app" + "github.com/osspkg/goppy/sdk/errors" + "github.com/osspkg/goppy/sdk/iosync" + "github.com/osspkg/goppy/sdk/log" + "github.com/osspkg/goppy/sdk/netutil" +) + +type ( + ConfigHttp struct { + Addr string `yaml:"addr"` + Network string `yaml:"network,omitempty"` + ReadTimeout time.Duration `yaml:"read_timeout,omitempty"` + WriteTimeout time.Duration `yaml:"write_timeout,omitempty"` + IdleTimeout time.Duration `yaml:"idle_timeout,omitempty"` + ShutdownTimeout time.Duration `yaml:"shutdown_timeout,omitempty"` + } + + ServerHttp struct { + conf ConfigHttp + serv *http.Server + handler http.Handler + + log log.Logger + wg iosync.Group + sync iosync.Switch + } +) + +// NewServerHttp create default http server +func NewServerHttp(conf ConfigHttp, handler http.Handler, l log.Logger) *ServerHttp { + srv := &ServerHttp{ + conf: conf, + handler: handler, + log: l, + sync: iosync.NewSwitch(), + wg: iosync.NewGroup(), + } + srv.validate() + return srv +} + +func (s *ServerHttp) validate() { + if s.conf.ReadTimeout == 0 { + s.conf.ReadTimeout = defaultTimeout + } + if s.conf.WriteTimeout == 0 { + s.conf.WriteTimeout = defaultTimeout + } + if s.conf.IdleTimeout == 0 { + s.conf.IdleTimeout = defaultTimeout + } + if s.conf.ShutdownTimeout == 0 { + s.conf.ShutdownTimeout = defaultShutdownTimeout + } + if len(s.conf.Network) == 0 { + s.conf.Network = defaultNetwork + } + if _, ok := networkType[s.conf.Network]; !ok { + s.conf.Network = defaultNetwork + } + s.conf.Addr = netutil.CheckHostPort(s.conf.Addr) +} + +// Up start http server +func (s *ServerHttp) Up(ctx application.Context) error { + if !s.sync.On() { + return errors.Wrapf(errServAlreadyRunning, "starting server on %s", s.conf.Addr) + } + s.serv = &http.Server{ + ReadTimeout: s.conf.ReadTimeout, + WriteTimeout: s.conf.WriteTimeout, + IdleTimeout: s.conf.IdleTimeout, + Handler: s.handler, + } + + nl, err := net.Listen(s.conf.Network, s.conf.Addr) + if err != nil { + return err + } + + s.log.WithFields(log.Fields{ + "ip": s.conf.Addr, + }).Infof("HTTP server started") + + s.wg.Background(func() { + if err = s.serv.Serve(nl); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.log.WithFields(log.Fields{ + "err": err.Error(), "ip": s.conf.Addr, + }).Errorf("HTTP server stopped") + ctx.Close() + return + } + s.log.WithFields(log.Fields{ + "ip": s.conf.Addr, + }).Infof("HTTP server stopped") + }) + return nil +} + +// Down stop http server +func (s *ServerHttp) Down() error { + if !s.sync.Off() { + return errors.Wrapf(errServAlreadyStopped, "stopping server on %s", s.conf.Addr) + } + ctx, cncl := context.WithTimeout(context.Background(), s.conf.ShutdownTimeout) + defer cncl() + err := s.serv.Shutdown(ctx) + s.wg.Wait() + return err +} diff --git a/sdk/webutil/signature/common.go b/sdk/webutil/signature/common.go new file mode 100644 index 0000000..b1fbf80 --- /dev/null +++ b/sdk/webutil/signature/common.go @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package signature + +import ( + "fmt" + "net/http" + "regexp" + + "github.com/osspkg/goppy/sdk/errors" +) + +const ( + Header = `Signature` + valueRegexp = `keyId=\"(.*)\",algorithm=\"(.*)\",signature=\"(.*)\"` + valueTmpl = `keyId="%s",algorithm="%s",signature="%s"` +) + +var ( + ErrInvalidSignature = errors.New("invalid signature header") + rex = regexp.MustCompile(valueRegexp) +) + +type Data struct { + ID string + Alg string + Hash string +} + +// Decode getting signature from header +func Decode(h http.Header) (s Data, err error) { + d := h.Get(Header) + r := rex.FindSubmatch([]byte(d)) + if len(r) != 4 { + err = ErrInvalidSignature + return + } + s.ID, s.Alg, s.Hash = string(r[1]), string(r[2]), string(r[3]) + return +} + +// Encode make and setting signature to header +func Encode(h http.Header, s Signature, body []byte) { + h.Set(Header, fmt.Sprintf(valueTmpl, s.ID(), s.Algorithm(), s.CreateString(body))) +} diff --git a/sdk/webutil/signature/signer.go b/sdk/webutil/signature/signer.go new file mode 100644 index 0000000..586e7ee --- /dev/null +++ b/sdk/webutil/signature/signer.go @@ -0,0 +1,165 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package signature + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "hash" + "sync" +) + +var _ Signature = (*_sig)(nil) + +type ( + _sig struct { + id string + hashFunc hash.Hash + alg string + lock sync.Mutex + } + + //Signature interface + Signature interface { + ID() string + Algorithm() string + Create(b []byte) []byte + CreateString(b []byte) string + Validate(b []byte, ex string) bool + } +) + +// NewSHA256 create sign sha256 +func NewSHA256(id, secret string) Signature { + return NewCustomSignature(id, secret, "hmac-sha256", sha256.New) +} + +// NewMD5 create sign md5 +func NewMD5(id, secret string) Signature { + return NewCustomSignature(id, secret, "hmac-md5", md5.New) +} + +// NewSHA512 create sign sha512 +func NewSHA512(id, secret string) Signature { + return NewCustomSignature(id, secret, "hmac-sha512", sha512.New) +} + +// NewCustomSignature create sign with custom hash function +func NewCustomSignature(id, secret, alg string, h func() hash.Hash) Signature { + return &_sig{ + id: id, + alg: alg, + hashFunc: hmac.New(h, []byte(secret)), + } +} + +// ID signature +func (s *_sig) ID() string { + return s.id +} + +// Algorithm getter +func (s *_sig) Algorithm() string { + return s.alg +} + +// Create getting hash as bytes +func (s *_sig) Create(b []byte) []byte { + s.lock.Lock() + defer func() { + s.hashFunc.Reset() + s.lock.Unlock() + }() + s.hashFunc.Write(b) + return s.hashFunc.Sum(nil) +} + +// CreateString getting hash as string +func (s *_sig) CreateString(b []byte) string { + return hex.EncodeToString(s.Create(b)) +} + +// Validate signature +func (s *_sig) Validate(b []byte, ex string) bool { + v, err := hex.DecodeString(ex) + if err != nil { + return false + } + return hmac.Equal(s.Create(b), v) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + +// _store storage +type ( + _store struct { + list map[string]Signature + lock sync.RWMutex + } + + Storage interface { + Add(s Signature) + Get(id string) Signature + Count() int + Del(id string) + Flush() + } +) + +// NewStorage init storage +func NewStorage() Storage { + return &_store{ + list: make(map[string]Signature), + } +} + +// Add adding to storage +func (ss *_store) Add(s Signature) { + ss.lock.Lock() + defer ss.lock.Unlock() + + ss.list[s.ID()] = s +} + +// Get getting to storage +func (ss *_store) Get(id string) Signature { + ss.lock.RLock() + defer ss.lock.RUnlock() + + if s, ok := ss.list[id]; ok { + return s + } + return nil +} + +// Count count sign in storage +func (ss *_store) Count() int { + ss.lock.RLock() + defer ss.lock.RUnlock() + l := len(ss.list) + return l +} + +// Del deleting from storage +func (ss *_store) Del(id string) { + ss.lock.Lock() + defer ss.lock.Unlock() + + delete(ss.list, id) +} + +// Flush removing all from storage +func (ss *_store) Flush() { + ss.lock.Lock() + defer ss.lock.Unlock() + + for k := range ss.list { + delete(ss.list, k) + } +} diff --git a/sdk/webutil/signature/signer_test.go b/sdk/webutil/signature/signer_test.go new file mode 100644 index 0000000..293b32f --- /dev/null +++ b/sdk/webutil/signature/signer_test.go @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package signature_test + +import ( + "testing" + + "github.com/osspkg/goppy/sdk/webutil/signature" + "github.com/stretchr/testify/require" +) + +func TestUnit_Signature(t *testing.T) { + sign := signature.NewSHA256("123", "456") + + body := []byte("hello") + hash := "b7089b0463bf766946fc467102671dbe91659f17a7a19145cd68138c36b00555" + + require.Equal(t, "123", sign.ID()) + require.Equal(t, hash, sign.CreateString(body)) + require.True(t, sign.Validate(body, hash)) +} + +func TestUnit_Storage(t *testing.T) { + store := signature.NewStorage() + + store.Add(signature.NewSHA256("1", "0")) + store.Add(signature.NewSHA256("2", "0")) + store.Add(signature.NewSHA256("3", "0")) + store.Add(signature.NewSHA256("5", "0")) + require.Equal(t, 4, store.Count()) + + store.Add(signature.NewMD5("5", "0")) + require.Equal(t, 4, store.Count()) + + require.Nil(t, store.Get("4")) + s := store.Get("5") + require.NotNil(t, s) + require.Equal(t, "5", s.ID()) + require.Equal(t, "hmac-md5", s.Algorithm()) +} diff --git a/sdk/webutil/version/common.go b/sdk/webutil/version/common.go new file mode 100644 index 0000000..4532e0a --- /dev/null +++ b/sdk/webutil/version/common.go @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2022-2023 Mikhail Knyazhev . All rights reserved. + * Use of this source code is governed by a BSD 3-Clause license that can be found in the LICENSE file. + */ + +package version + +import ( + "fmt" + "net/http" + "regexp" + "strconv" +) + +const ( + Header = `Accept` + valueRegexp = `application\/vnd.v(\d+)\+json` + valueTmpl = `application/vnd.v%d+json` +) + +var rex = regexp.MustCompile(valueRegexp) + +// Decode getting version from header +func Decode(h http.Header) uint64 { + d := h.Get(Header) + r := rex.FindSubmatch([]byte(d)) + if len(r) == 2 { + if v, err := strconv.ParseUint(string(r[1]), 10, 64); err == nil { + return v + } + } + return 0 +} + +// Encode setting version to header +func Encode(h http.Header, v uint64) { + h.Set(Header, fmt.Sprintf(valueTmpl, v)) +}