Skip to content
This repository has been archived by the owner on Feb 28, 2023. It is now read-only.

Commit

Permalink
feat: Global site & origin checker & support cookie
Browse files Browse the repository at this point in the history
Signed-off-by: qwqcode <[email protected]>
  • Loading branch information
qwqcode committed Jun 15, 2022
1 parent 0b9f364 commit 8329bee
Show file tree
Hide file tree
Showing 23 changed files with 283 additions and 203 deletions.
5 changes: 5 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Config struct {
SiteDefault string `mapstructure:"site_default" json:"site_default"` // 默认站点名(当请求无指定 site_name 时使用)
AdminUsers []AdminUserConf `mapstructure:"admin_users" json:"admin_users"` // 管理员账户
LoginTimeout int `mapstructure:"login_timeout" json:"login_timeout"` // 登陆超时
Cookie CookieConf `mapstructure:"cookie" json:"cookie"` // Cookie
Moderator ModeratorConf `mapstructure:"moderator" json:"moderator"` // 评论审查
Captcha CaptchaConf `mapstructure:"captcha" json:"captcha"` // 验证码
Email EmailConf `mapstructure:"email" json:"email"` // 邮箱提醒
Expand Down Expand Up @@ -87,6 +88,10 @@ type AdminUserConf struct {
Sites []string `mapstructure:"sites" json:"sites"`
}

type CookieConf struct {
Enabled bool `mapstructure:"enabled" json:"enabled"`
}

type ModeratorConf struct {
PendingDefault bool `mapstructure:"pending_default" json:"pending_default"`
ApiFailBlock bool `mapstructure:"api_fail_block" json:"api_fail_block"` // API 请求错误仍然拦截
Expand Down
5 changes: 4 additions & 1 deletion http/a_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,11 @@ func InitCorsControl(e *echo.Echo) {
}

e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: allowOrigins,
AllowOrigins: allowOrigins,
AllowCredentials: true, // allow cors with cookies
}))

e.Use(SiteOriginMiddleware())
}

// 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号)
Expand Down
1 change: 1 addition & 0 deletions http/a_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func InitRouter(e *echo.Echo) {
api.POST("/user-get", action.UserGet)
api.POST("/login", action.Login)
api.POST("/login-status", action.LoginStatus)
api.POST("/logout", action.Logout)
api.POST("/mark-read", action.MarkRead)
api.POST("/vote", action.Vote)
api.POST("/pv", action.PV)
Expand Down
147 changes: 147 additions & 0 deletions http/a_site_origin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package http

import (
"fmt"
"net/url"
"strings"

"github.com/ArtalkJS/ArtalkGo/config"
"github.com/ArtalkJS/ArtalkGo/lib"
"github.com/ArtalkJS/ArtalkGo/model"
"github.com/labstack/echo/v4"
)

// 站点隔离 & Origin 控制
func SiteOriginMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
siteName := c.FormValue("site_name")
siteID := uint(0)
var site *model.Site = nil

siteAll := false

// 请求站点名 == "__ATK_SITE_ALL" 时取消站点隔离
if siteName == lib.ATK_SITE_ALL {
if !CheckIsAdminReq(c) {
return RespError(c, "仅管理员查询允许取消站点隔离")
}

siteAll = true
} else {
// 请求站点名为空,使用默认 site
if siteName == "" {
siteName = strings.TrimSpace(config.Instance.SiteDefault)
if siteName != "" {
model.FindCreateSite(siteName) // 默认站点不存在则创建
}
}

findSite := model.FindSite(siteName)
if findSite.IsEmpty() {
return RespError(c, fmt.Sprintf("未找到站点:`%s`,请控制台创建站点", siteName), Map{
"err_no_site": true,
})
}
site = &findSite
siteID = findSite.ID
}

// 检测 Origin 合法性 (防止 CSRF 攻击)
if isOK, resp := CheckOrigin(c, site); !isOK {
return resp
}

// 设置上下文
c.Set(lib.CTX_KEY_ATK_SITE_ID, siteID)
c.Set(lib.CTX_KEY_ATK_SITE_NAME, siteName)
c.Set(lib.CTX_KEY_ATK_SITE_ALL, siteAll)

return next(c)
}
}
}

func UseSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) {
if destID != nil {
*destID = c.Get(lib.CTX_KEY_ATK_SITE_ID).(uint)
}
if siteName != nil {
*siteName = c.Get(lib.CTX_KEY_ATK_SITE_NAME).(string)
}
if destSiteAll != nil {
*destSiteAll = c.Get(lib.CTX_KEY_ATK_SITE_ALL).(bool)
}
}

// 检测 Origin 合法性
// 防止 CSRF 攻击
// @see https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) {
// 可信来源 URL
allowSrcURLs := []string{}

// 用户配置
allowSrcURLs = append(allowSrcURLs, config.Instance.TrustedDomains...) // 允许配置文件域名
if allowSite != nil {
allowSrcURLs = append(allowSrcURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名
}
if len(allowSrcURLs) == 0 {
return true, nil // 若用户配置列表中无数据,则取消控制
}
if lib.ContainsStr(allowSrcURLs, "*") {
return true, nil // 列表中出现通配符关闭控制
}

host := c.Request().Host
realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host")
if realHostUnderProxy != "" {
host = realHostUnderProxy
}

// 读取 Origin 数据
// @note Origin 标头在前端 fetch POST 操作中总是携带的,
// 即使配置 Referrer-Policy: no-referrer
// @see https://stackoverflow.com/questions/42239643/when-do-browsers-send-the-origin-header-when-do-browsers-set-the-origin-to-null
origin := c.Request().Header.Get(echo.HeaderOrigin)
if origin == "" || origin == "null" {
// 从 Referer 获取 Origin
referer := c.Request().Referer()
if referer == "" {
return false, RespError(c, "无效请求,Origin 无法获取")
}
origin = referer
}

pOrigin, err := url.Parse(origin)
if err != nil {
return false, RespError(c, "Origin 不合法")
}

// 系统配置:默认允许来自相同域名的请求
allowSrcURLs = append(allowSrcURLs, c.Scheme()+"://"+host)

allowSrcURLs = lib.RemoveDuplicates(allowSrcURLs) // 去重
for _, a := range allowSrcURLs {
a = strings.TrimSpace(a)
if a == "" {
continue
}
pAllow, err := url.Parse(a)
if err != nil {
continue
}

// 在可信来源列表中匹配 Referer 的 host 部分 (含端口) 则放行
// @see https://web.dev/referrer-best-practices/
// Referrer-Policy 不能设为 no-referer,
// Chrome v85+ 默认为:strict-origin-when-cross-origin。
// 前端页面 head 不配置 <meta name="referrer" content="no-referer" />,
// 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口)
if pAllow.Host == pOrigin.Host {
return true, nil
}
}

return false, RespError(c, "非法请求,请检查可信域名配置")
}
14 changes: 14 additions & 0 deletions http/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ func LoginGetUserToken(user model.User) string {
return t
}

func GetJwtStrByReqCookie(c echo.Context) string {
if !config.Instance.Cookie.Enabled {
return ""
}
cookie, err := c.Cookie(lib.COOKIE_KEY_ATK_AUTH)
if err != nil {
return ""
}
return cookie.Value
}

func GetJwtInstanceByReq(c echo.Context) *jwt.Token {
token := c.QueryParam("token")
if token == "" {
Expand All @@ -56,6 +67,9 @@ func GetJwtInstanceByReq(c echo.Context) *jwt.Token {
token = c.Request().Header.Get("Authorization")
token = strings.TrimPrefix(token, "Bearer ")
}
if token == "" {
token = GetJwtStrByReqCookie(c)
}
if token == "" {
return nil
}
Expand Down
8 changes: 3 additions & 5 deletions http/admin_comment_del.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
type ParamsCommentDel struct {
ID uint `mapstructure:"id" param:"required"`

SiteName string `mapstructure:"site_name"`
SiteName string
SiteID uint
SiteAll bool
}
Expand All @@ -19,10 +19,8 @@ func (a *action) AdminCommentDel(c echo.Context) error {
return resp
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll)

// find comment
comment := model.FindComment(p.ID)
Expand Down
10 changes: 4 additions & 6 deletions http/admin_comment_edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (

type ParamsCommentEdit struct {
// 查询值
ID uint `mapstructure:"id" param:"required"`
SiteName string `mapstructure:"site_name"`
ID uint `mapstructure:"id" param:"required"`
SiteName string
SiteID uint
SiteAll bool

Expand All @@ -35,10 +35,8 @@ func (a *action) AdminCommentEdit(c echo.Context) error {
return resp
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll)

// find comment
comment := model.FindComment(p.ID)
Expand Down
8 changes: 3 additions & 5 deletions http/admin_page_del.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type ParamsAdminPageDel struct {
Key string `mapstructure:"key" param:"required"`
SiteName string `mapstructure:"site_name"`
SiteName string
SiteID uint
}

Expand All @@ -17,10 +17,8 @@ func (a *action) AdminPageDel(c echo.Context) error {
return resp
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, nil)

page := model.FindPage(p.Key, p.SiteName)
if page.IsEmpty() {
Expand Down
10 changes: 4 additions & 6 deletions http/admin_page_edit.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (

type ParamsAdminPageEdit struct {
// 查询值
ID uint `mapstructure:"id"`
SiteName string `mapstructure:"site_name"`
ID uint `mapstructure:"id"`
SiteName string
SiteID uint

// 修改值
Expand All @@ -29,10 +29,8 @@ func (a *action) AdminPageEdit(c echo.Context) error {
return RespError(c, "page key 不能为空白字符")
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, nil)

// find page
var page = model.FindPageByID(p.ID)
Expand Down
6 changes: 4 additions & 2 deletions http/admin_page_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
)

type ParamsAdminPageFetch struct {
ID uint `mapstructure:"id"`
SiteName string `mapstructure:"site_name"`
ID uint `mapstructure:"id"`
SiteName string

GetStatus bool `mapstructure:"get_status"`
}
Expand All @@ -26,6 +26,8 @@ func (a *action) AdminPageFetch(c echo.Context) error {
return resp
}

UseSite(c, &p.SiteName, nil, nil)

// 状态获取
if p.GetStatus {
if allPageFetching {
Expand Down
8 changes: 3 additions & 5 deletions http/admin_page_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type ParamsAdminPageGet struct {
SiteName string `mapstructure:"site_name"`
SiteName string
SiteID uint
SiteAll bool
Limit int `mapstructure:"limit"`
Expand All @@ -24,10 +24,8 @@ func (a *action) AdminPageGet(c echo.Context) error {
return resp
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll)

if !IsAdminHasSiteAccess(c, p.SiteName) {
return RespError(c, "无权操作")
Expand Down
8 changes: 3 additions & 5 deletions http/comment_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type ParamsAdd struct {
PageTitle string `mapstructure:"page_title"`

Token string `mapstructure:"token"`
SiteName string `mapstructure:"site_name"`
SiteName string
SiteID uint
}

Expand Down Expand Up @@ -64,10 +64,8 @@ func (a *action) Add(c echo.Context) error {
// record action for limiting action
RecordAction(c)

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, nil); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, nil)

// find page
page := model.FindCreatePage(p.PageKey, p.PageTitle, p.SiteName)
Expand Down
8 changes: 3 additions & 5 deletions http/comment_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

type ParamsGet struct {
PageKey string `mapstructure:"page_key" param:"required"`
SiteName string `mapstructure:"site_name"`
SiteName string

Limit int `mapstructure:"limit"`
Offset int `mapstructure:"offset"`
Expand Down Expand Up @@ -47,10 +47,8 @@ func (a *action) Get(c echo.Context) error {
return resp
}

// find site
if isOK, resp := CheckSite(c, &p.SiteName, &p.SiteID, &p.SiteAll); !isOK {
return resp
}
// use site
UseSite(c, &p.SiteName, &p.SiteID, &p.SiteAll)

// find page
var page model.Page
Expand Down
Loading

0 comments on commit 8329bee

Please sign in to comment.