From 8329bee93cf0e4e73598b6d5a2947b2047484df6 Mon Sep 17 00:00:00 2001 From: qwqcode Date: Wed, 15 Jun 2022 18:28:26 +0800 Subject: [PATCH] feat: Global site & origin checker & support cookie Signed-off-by: qwqcode --- config/config.go | 5 ++ http/a_http.go | 5 +- http/a_router.go | 1 + http/a_site_origin.go | 147 +++++++++++++++++++++++++++++++++++++ http/admin.go | 14 ++++ http/admin_comment_del.go | 8 +- http/admin_comment_edit.go | 10 +-- http/admin_page_del.go | 8 +- http/admin_page_edit.go | 10 +-- http/admin_page_fetch.go | 6 +- http/admin_page_get.go | 8 +- http/comment_add.go | 8 +- http/comment_get.go | 8 +- http/img_upload.go | 8 +- http/mark_read.go | 8 +- http/pv.go | 13 ++-- http/stat.go | 8 +- http/user_get.go | 13 +--- http/user_login.go | 36 ++++++++- http/user_logout.go | 23 ++++++ http/utils.go | 123 ------------------------------- http/vote.go | 8 +- lib/constants.go | 8 ++ 23 files changed, 283 insertions(+), 203 deletions(-) create mode 100644 http/a_site_origin.go create mode 100644 http/user_logout.go diff --git a/config/config.go b/config/config.go index 1ddfc3f..8ad6df2 100644 --- a/config/config.go +++ b/config/config.go @@ -19,6 +19,7 @@ type Config struct { SiteDefault string `mapstructure:"site_default" json:"site_default"` // 默认站点名(当请求无指定 site_name 时使用) AdminUsers []AdminUserConf `mapstructure:"admin_users" json:"admin_users"` // 管理员账户 LoginTimeout int `mapstructure:"login_timeout" json:"login_timeout"` // 登陆超时 + Cookie CookieConf `mapstructure:"cookie" json:"cookie"` // Cookie Moderator ModeratorConf `mapstructure:"moderator" json:"moderator"` // 评论审查 Captcha CaptchaConf `mapstructure:"captcha" json:"captcha"` // 验证码 Email EmailConf `mapstructure:"email" json:"email"` // 邮箱提醒 @@ -87,6 +88,10 @@ type AdminUserConf struct { Sites []string `mapstructure:"sites" json:"sites"` } +type CookieConf struct { + Enabled bool `mapstructure:"enabled" json:"enabled"` +} + type ModeratorConf struct { PendingDefault bool `mapstructure:"pending_default" json:"pending_default"` ApiFailBlock bool `mapstructure:"api_fail_block" json:"api_fail_block"` // API 请求错误仍然拦截 diff --git a/http/a_http.go b/http/a_http.go index ad21686..a1910e4 100644 --- a/http/a_http.go +++ b/http/a_http.go @@ -78,8 +78,11 @@ func InitCorsControl(e *echo.Echo) { } e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: allowOrigins, + AllowOrigins: allowOrigins, + AllowCredentials: true, // allow cors with cookies })) + + e.Use(SiteOriginMiddleware()) } // 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号) diff --git a/http/a_router.go b/http/a_router.go index 077498b..67d22d5 100644 --- a/http/a_router.go +++ b/http/a_router.go @@ -34,6 +34,7 @@ func InitRouter(e *echo.Echo) { api.POST("/user-get", action.UserGet) api.POST("/login", action.Login) api.POST("/login-status", action.LoginStatus) + api.POST("/logout", action.Logout) api.POST("/mark-read", action.MarkRead) api.POST("/vote", action.Vote) api.POST("/pv", action.PV) diff --git a/http/a_site_origin.go b/http/a_site_origin.go new file mode 100644 index 0000000..27188fe --- /dev/null +++ b/http/a_site_origin.go @@ -0,0 +1,147 @@ +package http + +import ( + "fmt" + "net/url" + "strings" + + "github.com/ArtalkJS/ArtalkGo/config" + "github.com/ArtalkJS/ArtalkGo/lib" + "github.com/ArtalkJS/ArtalkGo/model" + "github.com/labstack/echo/v4" +) + +// 站点隔离 & Origin 控制 +func SiteOriginMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + siteName := c.FormValue("site_name") + siteID := uint(0) + var site *model.Site = nil + + siteAll := false + + // 请求站点名 == "__ATK_SITE_ALL" 时取消站点隔离 + if siteName == lib.ATK_SITE_ALL { + if !CheckIsAdminReq(c) { + return RespError(c, "仅管理员查询允许取消站点隔离") + } + + siteAll = true + } else { + // 请求站点名为空,使用默认 site + if siteName == "" { + siteName = strings.TrimSpace(config.Instance.SiteDefault) + if siteName != "" { + model.FindCreateSite(siteName) // 默认站点不存在则创建 + } + } + + findSite := model.FindSite(siteName) + if findSite.IsEmpty() { + return RespError(c, fmt.Sprintf("未找到站点:`%s`,请控制台创建站点", siteName), Map{ + "err_no_site": true, + }) + } + site = &findSite + siteID = findSite.ID + } + + // 检测 Origin 合法性 (防止 CSRF 攻击) + if isOK, resp := CheckOrigin(c, site); !isOK { + return resp + } + + // 设置上下文 + c.Set(lib.CTX_KEY_ATK_SITE_ID, siteID) + c.Set(lib.CTX_KEY_ATK_SITE_NAME, siteName) + c.Set(lib.CTX_KEY_ATK_SITE_ALL, siteAll) + + return next(c) + } + } +} + +func UseSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) { + if destID != nil { + *destID = c.Get(lib.CTX_KEY_ATK_SITE_ID).(uint) + } + if siteName != nil { + *siteName = c.Get(lib.CTX_KEY_ATK_SITE_NAME).(string) + } + if destSiteAll != nil { + *destSiteAll = c.Get(lib.CTX_KEY_ATK_SITE_ALL).(bool) + } +} + +// 检测 Origin 合法性 +// 防止 CSRF 攻击 +// @see https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html +func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) { + // 可信来源 URL + allowSrcURLs := []string{} + + // 用户配置 + allowSrcURLs = append(allowSrcURLs, config.Instance.TrustedDomains...) // 允许配置文件域名 + if allowSite != nil { + allowSrcURLs = append(allowSrcURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名 + } + if len(allowSrcURLs) == 0 { + return true, nil // 若用户配置列表中无数据,则取消控制 + } + if lib.ContainsStr(allowSrcURLs, "*") { + return true, nil // 列表中出现通配符关闭控制 + } + + host := c.Request().Host + realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host") + if realHostUnderProxy != "" { + host = realHostUnderProxy + } + + // 读取 Origin 数据 + // @note Origin 标头在前端 fetch POST 操作中总是携带的, + // 即使配置 Referrer-Policy: no-referrer + // @see https://stackoverflow.com/questions/42239643/when-do-browsers-send-the-origin-header-when-do-browsers-set-the-origin-to-null + origin := c.Request().Header.Get(echo.HeaderOrigin) + if origin == "" || origin == "null" { + // 从 Referer 获取 Origin + referer := c.Request().Referer() + if referer == "" { + return false, RespError(c, "无效请求,Origin 无法获取") + } + origin = referer + } + + pOrigin, err := url.Parse(origin) + if err != nil { + return false, RespError(c, "Origin 不合法") + } + + // 系统配置:默认允许来自相同域名的请求 + allowSrcURLs = append(allowSrcURLs, c.Scheme()+"://"+host) + + allowSrcURLs = lib.RemoveDuplicates(allowSrcURLs) // 去重 + for _, a := range allowSrcURLs { + a = strings.TrimSpace(a) + if a == "" { + continue + } + pAllow, err := url.Parse(a) + if err != nil { + continue + } + + // 在可信来源列表中匹配 Referer 的 host 部分 (含端口) 则放行 + // @see https://web.dev/referrer-best-practices/ + // Referrer-Policy 不能设为 no-referer, + // Chrome v85+ 默认为:strict-origin-when-cross-origin。 + // 前端页面 head 不配置 , + // 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口) + if pAllow.Host == pOrigin.Host { + return true, nil + } + } + + return false, RespError(c, "非法请求,请检查可信域名配置") +} diff --git a/http/admin.go b/http/admin.go index 6033829..80146ed 100644 --- a/http/admin.go +++ b/http/admin.go @@ -47,6 +47,17 @@ func LoginGetUserToken(user model.User) string { return t } +func GetJwtStrByReqCookie(c echo.Context) string { + if !config.Instance.Cookie.Enabled { + return "" + } + cookie, err := c.Cookie(lib.COOKIE_KEY_ATK_AUTH) + if err != nil { + return "" + } + return cookie.Value +} + func GetJwtInstanceByReq(c echo.Context) *jwt.Token { token := c.QueryParam("token") if token == "" { @@ -56,6 +67,9 @@ func GetJwtInstanceByReq(c echo.Context) *jwt.Token { token = c.Request().Header.Get("Authorization") token = strings.TrimPrefix(token, "Bearer ") } + if token == "" { + token = GetJwtStrByReqCookie(c) + } if token == "" { return nil } diff --git a/http/admin_comment_del.go b/http/admin_comment_del.go index 96538ec..e1b6d32 100644 --- a/http/admin_comment_del.go +++ b/http/admin_comment_del.go @@ -8,7 +8,7 @@ import ( type ParamsCommentDel struct { ID uint `mapstructure:"id" param:"required"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint SiteAll bool } @@ -19,10 +19,8 @@ func (a *action) AdminCommentDel(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // find comment comment := model.FindComment(p.ID) diff --git a/http/admin_comment_edit.go b/http/admin_comment_edit.go index 0e19eda..ed3dbab 100644 --- a/http/admin_comment_edit.go +++ b/http/admin_comment_edit.go @@ -10,8 +10,8 @@ import ( type ParamsCommentEdit struct { // 查询值 - ID uint `mapstructure:"id" param:"required"` - SiteName string `mapstructure:"site_name"` + ID uint `mapstructure:"id" param:"required"` + SiteName string SiteID uint SiteAll bool @@ -35,10 +35,8 @@ func (a *action) AdminCommentEdit(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // find comment comment := model.FindComment(p.ID) diff --git a/http/admin_page_del.go b/http/admin_page_del.go index f10b008..52e0932 100644 --- a/http/admin_page_del.go +++ b/http/admin_page_del.go @@ -7,7 +7,7 @@ import ( type ParamsAdminPageDel struct { Key string `mapstructure:"key" param:"required"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint } @@ -17,10 +17,8 @@ func (a *action) AdminPageDel(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, nil) page := model.FindPage(p.Key, p.SiteName) if page.IsEmpty() { diff --git a/http/admin_page_edit.go b/http/admin_page_edit.go index ee7c705..5416755 100644 --- a/http/admin_page_edit.go +++ b/http/admin_page_edit.go @@ -9,8 +9,8 @@ import ( type ParamsAdminPageEdit struct { // 查询值 - ID uint `mapstructure:"id"` - SiteName string `mapstructure:"site_name"` + ID uint `mapstructure:"id"` + SiteName string SiteID uint // 修改值 @@ -29,10 +29,8 @@ func (a *action) AdminPageEdit(c echo.Context) error { return RespError(c, "page key 不能为空白字符") } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, nil) // find page var page = model.FindPageByID(p.ID) diff --git a/http/admin_page_fetch.go b/http/admin_page_fetch.go index 319989f..c06b06c 100644 --- a/http/admin_page_fetch.go +++ b/http/admin_page_fetch.go @@ -10,8 +10,8 @@ import ( ) type ParamsAdminPageFetch struct { - ID uint `mapstructure:"id"` - SiteName string `mapstructure:"site_name"` + ID uint `mapstructure:"id"` + SiteName string GetStatus bool `mapstructure:"get_status"` } @@ -26,6 +26,8 @@ func (a *action) AdminPageFetch(c echo.Context) error { return resp } + UseSite(c, &p.SiteName, nil, nil) + // 状态获取 if p.GetStatus { if allPageFetching { diff --git a/http/admin_page_get.go b/http/admin_page_get.go index 3a53533..0151a3a 100644 --- a/http/admin_page_get.go +++ b/http/admin_page_get.go @@ -6,7 +6,7 @@ import ( ) type ParamsAdminPageGet struct { - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint SiteAll bool Limit int `mapstructure:"limit"` @@ -24,10 +24,8 @@ func (a *action) AdminPageGet(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) if !IsAdminHasSiteAccess(c, p.SiteName) { return RespError(c, "无权操作") diff --git a/http/comment_add.go b/http/comment_add.go index 09d771b..779d221 100644 --- a/http/comment_add.go +++ b/http/comment_add.go @@ -25,7 +25,7 @@ type ParamsAdd struct { PageTitle string `mapstructure:"page_title"` Token string `mapstructure:"token"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint } @@ -64,10 +64,8 @@ func (a *action) Add(c echo.Context) error { // record action for limiting action RecordAction(c) - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, nil) // find page page := model.FindCreatePage(p.PageKey, p.PageTitle, p.SiteName) diff --git a/http/comment_get.go b/http/comment_get.go index bf92705..4fcabd6 100644 --- a/http/comment_get.go +++ b/http/comment_get.go @@ -8,7 +8,7 @@ import ( type ParamsGet struct { PageKey string `mapstructure:"page_key" param:"required"` - SiteName string `mapstructure:"site_name"` + SiteName string Limit int `mapstructure:"limit"` Offset int `mapstructure:"offset"` @@ -47,10 +47,8 @@ func (a *action) Get(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // find page var page model.Page diff --git a/http/img_upload.go b/http/img_upload.go index 78843b5..e65ac32 100644 --- a/http/img_upload.go +++ b/http/img_upload.go @@ -27,7 +27,7 @@ type ParamsImgUpload struct { PageKey string `mapstructure:"page_key" param:"required"` PageTitle string `mapstructure:"page_title"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint SiteAll bool @@ -51,10 +51,8 @@ func (a *action) ImgUpload(c echo.Context) error { return RespError(c, "Invalid email") } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // 记录请求次数 (for 请求频率限制) RecordAction(c) diff --git a/http/mark_read.go b/http/mark_read.go index 03579f8..f870ff3 100644 --- a/http/mark_read.go +++ b/http/mark_read.go @@ -12,7 +12,7 @@ type ParamsMarkRead struct { Email string `mapstructure:"email"` AllRead bool `mapstructure:"all_read"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint SiteAll bool } @@ -23,10 +23,8 @@ func (a *action) MarkRead(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // all read if p.AllRead { diff --git a/http/pv.go b/http/pv.go index 7b498e7..c607769 100644 --- a/http/pv.go +++ b/http/pv.go @@ -9,10 +9,9 @@ type ParamsPV struct { PageKey string `mapstructure:"page_key" param:"required"` PageTitle string `mapstructure:"page_title"` - SiteName string `mapstructure:"site_name"` - - SiteID uint - SiteAll bool + SiteName string + SiteID uint + SiteAll bool } func (a *action) PV(c echo.Context) error { @@ -21,10 +20,8 @@ func (a *action) PV(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // find page page := model.FindCreatePage(p.PageKey, p.PageTitle, p.SiteName) diff --git a/http/stat.go b/http/stat.go index 473503e..2534481 100644 --- a/http/stat.go +++ b/http/stat.go @@ -11,7 +11,7 @@ import ( type ParamsStat struct { Type string `mapstructure:"type" param:"required"` - SiteName string `mapstructure:"site_name"` + SiteName string PageKeys string `mapstructure:"page_keys"` Limit int `mapstructure:"limit"` @@ -26,10 +26,8 @@ func (a *action) Stat(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // Limit 限定 if p.Limit <= 0 { diff --git a/http/user_get.go b/http/user_get.go index c125942..e45b14e 100644 --- a/http/user_get.go +++ b/http/user_get.go @@ -16,24 +16,19 @@ func (a *action) UserGet(c echo.Context) error { return resp } - user := model.FindUser(p.Name, p.Email) + // login status + isLogin := !GetUserByReq(c).IsEmpty() + user := model.FindUser(p.Name, p.Email) if user.IsEmpty() { return RespData(c, Map{ "user": nil, - "is_login": false, + "is_login": isLogin, "unread": []interface{}{}, "unread_count": 0, }) } - // loginned user check - isLogin := false - tUser := GetUserByReq(c) - if tUser.Name == p.Name && tUser.Email == p.Email { - isLogin = true - } - // unread notifies unreadNotifies := model.FindUnreadNotifies(user.ID) diff --git a/http/user_login.go b/http/user_login.go index e7acbe5..767f5e2 100644 --- a/http/user_login.go +++ b/http/user_login.go @@ -3,8 +3,12 @@ package http import ( "crypto/md5" "fmt" + "net/http" "strings" + "time" + "github.com/ArtalkJS/ArtalkGo/config" + "github.com/ArtalkJS/ArtalkGo/lib" "github.com/ArtalkJS/ArtalkGo/model" "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" @@ -25,8 +29,7 @@ func (a *action) Login(c echo.Context) error { // record action for limiting action RecordAction(c) - user := model.FindUser(p.Name, p.Email) // name = ? OR email = ? - + user := model.FindUser(p.Name, p.Email) // name = ? AND email = ? if user.IsEmpty() { return RespError(c, "验证失败") } @@ -60,11 +63,38 @@ func (a *action) Login(c echo.Context) error { return RespError(c, "验证失败") } + jwtToken := LoginGetUserToken(user) + setAuthCookie(c, jwtToken, time.Now().Add(time.Second*time.Duration(config.Instance.LoginTimeout))) + return RespData(c, Map{ - "token": LoginGetUserToken(user), + "token": jwtToken, }) } +func setAuthCookie(c echo.Context, jwtToken string, expires time.Time) { + if !config.Instance.Cookie.Enabled { + return + } + + // save jwt token to cookie + cookie := new(http.Cookie) + cookie.Name = lib.COOKIE_KEY_ATK_AUTH + cookie.Value = jwtToken + cookie.Expires = expires + + // @see https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Cookies + // @see https://owasp.org/www-project-web-security-testing-guide/v41/4-Web_Application_Security_Testing/06-Session_Management_Testing/02-Testing_for_Cookies_Attributes + cookie.Path = "/" + cookie.HttpOnly = true // prevent XSS + cookie.Secure = true // https only + cookie.SameSite = http.SameSiteDefaultMode // for cors-request + + // @note cookie secure is not working on localhost + // @see https://bugs.chromium.org/p/chromium/issues/detail?id=1177877#c7 + + c.SetCookie(cookie) +} + func HashPassword(password string) (string, error) { bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14) return string(bytes), err diff --git a/http/user_logout.go b/http/user_logout.go new file mode 100644 index 0000000..ef888de --- /dev/null +++ b/http/user_logout.go @@ -0,0 +1,23 @@ +package http + +import ( + "time" + + "github.com/ArtalkJS/ArtalkGo/config" + "github.com/labstack/echo/v4" +) + +func (a *action) Logout(c echo.Context) error { + if !config.Instance.Cookie.Enabled { + return RespError(c, "API 未启用 Cookie") + } + + if GetJwtStrByReqCookie(c) == "" { + return RespError(c, "未登录,无需注销") + } + + // same as login, remove cookie + setAuthCookie(c, "", time.Now().AddDate(0, 0, -1)) + + return RespSuccess(c) +} diff --git a/http/utils.go b/http/utils.go index 73c6662..b7044bb 100644 --- a/http/utils.go +++ b/http/utils.go @@ -1,14 +1,10 @@ package http import ( - "fmt" - "net/url" "reflect" "strconv" "strings" - "github.com/ArtalkJS/ArtalkGo/config" - "github.com/ArtalkJS/ArtalkGo/lib" "github.com/ArtalkJS/ArtalkGo/model" "github.com/labstack/echo/v4" "github.com/mitchellh/mapstructure" @@ -105,122 +101,3 @@ func CheckIsAllowed(c echo.Context, name string, email string, page model.Page, return true, nil } - -// 检测 Origin 合法性 -// 防止 CSRF 攻击 -// @see https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html -func CheckOrigin(c echo.Context, site model.Site) (bool, error) { - isAdminReq := CheckIsAdminReq(c) - if isAdminReq { - return true, nil // 管理员直接允许 - } - - // 可信来源 URL - allowSrcURLs := []string{} - - // 用户配置 - allowSrcURLs = append(allowSrcURLs, config.Instance.TrustedDomains...) // 允许配置文件域名 - allowSrcURLs = append(allowSrcURLs, site.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名 - if len(allowSrcURLs) == 0 { - return true, nil // 若用户配置列表中无数据,则取消控制 - } - if lib.ContainsStr(allowSrcURLs, "*") { - return true, nil // 列表中出现通配符关闭控制 - } - - host := c.Request().Host - realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host") - if realHostUnderProxy != "" { - host = realHostUnderProxy - } - - // 读取 Origin 数据 - // @note Origin 标头在前端 fetch POST 操作中总是携带的, - // 即使配置 Referrer-Policy: no-referrer - // @see https://stackoverflow.com/questions/42239643/when-do-browsers-send-the-origin-header-when-do-browsers-set-the-origin-to-null - origin := c.Request().Header.Get(echo.HeaderOrigin) - if origin == "" || origin == "null" { - // 从 Referer 获取 Origin - referer := c.Request().Referer() - if referer == "" { - return false, RespError(c, "无效请求,Origin 无法获取") - } - origin = referer - } - - pOrigin, err := url.Parse(origin) - if err != nil { - return false, RespError(c, "Origin 不合法") - } - - // 系统配置:默认允许来自相同域名的请求 - allowSrcURLs = append(allowSrcURLs, c.Scheme()+"://"+host) - - allowSrcURLs = lib.RemoveDuplicates(allowSrcURLs) // 去重 - for _, a := range allowSrcURLs { - a = strings.TrimSpace(a) - if a == "" { - continue - } - pAllow, err := url.Parse(a) - if err != nil { - continue - } - - // 在可信来源列表中匹配 Referer 的 host 部分 (含端口) 则放行 - // @see https://web.dev/referrer-best-practices/ - // Referrer-Policy 不能设为 no-referer, - // Chrome v85+ 默认为:strict-origin-when-cross-origin。 - // 前端页面 head 不配置 , - // 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口) - if pAllow.Host == pOrigin.Host { - return true, nil - } - } - - return false, RespError(c, "非法请求,请检查可信域名配置") -} - -func CheckSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) (bool, error) { - // 启用源 SiteAll - if destSiteAll != nil { - // 传入站点名参数 == "__ATK_SITE_ALL" 时取消站点隔离 - if *siteName == lib.ATK_SITE_ALL { - if !CheckIsAdminReq(c) { - return false, RespError(c, "仅管理员查询允许取消站点隔离") - } - *destSiteAll = true - return true, nil - } else { - *destSiteAll = false - } - } - - if *siteName == "" { - // 传入值为空,使用默认 site - siteDefault := strings.TrimSpace(config.Instance.SiteDefault) - if siteDefault != "" { - // 没有则创建 - model.FindCreateSite(siteDefault) - } - *siteName = siteDefault // 更新源 name - - return true, nil - } - - site := model.FindSite(*siteName) - if site.IsEmpty() { - return false, RespError(c, fmt.Sprintf("未找到站点:`%s`,请控制台创建站点", *siteName), Map{ - "err_no_site": true, - }) - } - - // 检测 Origin 合法性 - if isOK, resp := CheckOrigin(c, site); !isOK { - return false, resp - } - - *destID = site.ID // 更新源 id - - return true, nil -} diff --git a/http/vote.go b/http/vote.go index b867d80..eb45edd 100644 --- a/http/vote.go +++ b/http/vote.go @@ -14,7 +14,7 @@ type ParamsVote struct { Name string `mapstructure:"name"` Email string `mapstructure:"email"` - SiteName string `mapstructure:"site_name"` + SiteName string SiteID uint SiteAll bool } @@ -25,10 +25,8 @@ func (a *action) Vote(c echo.Context) error { return resp } - // find site - if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK { - return resp - } + // use site + UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll) // find user var user model.User diff --git a/lib/constants.go b/lib/constants.go index 0efd833..f4db42d 100644 --- a/lib/constants.go +++ b/lib/constants.go @@ -4,3 +4,11 @@ package lib // 所有站点 const ATK_SITE_ALL = "__ATK_SITE_ALL" + +// Cookie 键 +const COOKIE_KEY_ATK_AUTH = "ATK_AUTH" + +// ctx keys +const CTX_KEY_ATK_SITE_ID = "atk_site_id" +const CTX_KEY_ATK_SITE_NAME = "atk_site_name" +const CTX_KEY_ATK_SITE_ALL = "atk_site_all"