Skip to content
This repository has been archived by the owner on Feb 28, 2023. It is now read-only.

Commit

Permalink
refactor: http origin checker
Browse files Browse the repository at this point in the history
Signed-off-by: qwqcode <[email protected]>
  • Loading branch information
qwqcode committed Oct 29, 2022
1 parent 1c095c3 commit 4494e53
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 98 deletions.
9 changes: 3 additions & 6 deletions artalk-go.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,7 @@ cache:
db: 0

# 可信域名
trusted_domains:
- "https://artalk.你的域名:23366"
trusted_domains: [] # 例如:["https://artalk.example.com:23366"]

# SSL
ssl:
Expand Down Expand Up @@ -270,11 +269,11 @@ frontend:
countEl: "#ArtalkCount"
# 编辑器实时预览功能
preview: true
# 平铺模式
# 平铺模式 ["auto", true, false]
flatMode: "auto"
# 最大嵌套层数
nestMax: 2
# 嵌套评论排序规则
# 嵌套评论排序规则 ["DATE_ASC", "DATE_DESC", "VOTE_UP_DESC"]
nestSort: DATE_ASC
# 头像
gravatar:
Expand All @@ -298,5 +297,3 @@ frontend:
reqTimeout: 15000
# 版本检测
versionCheck: true
# 语言设定
locale: zh-CN
58 changes: 25 additions & 33 deletions http/a_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package http

import (
"fmt"
libURL "net/url"

"github.com/ArtalkJS/ArtalkGo/config"
"github.com/ArtalkJS/ArtalkGo/lib"
Expand Down Expand Up @@ -55,40 +54,33 @@ func Run() {
}

func InitCorsControl(e *echo.Echo) {
siteUrls := []string{}
for _, site := range model.FindAllSitesCooked() {
siteUrls = append(siteUrls, site.Urls...)
}

allowOrigins := []string{}
allowOrigins = append(allowOrigins, config.Instance.TrustedDomains...) // 导入配置中的可信域名
allowOrigins = append(allowOrigins, siteUrls...) // 导入数据库中的站点 urls

if lib.ContainsStr(allowOrigins, "*") {
allowOrigins = []string{"*"} // 通配符关闭跨域控制
} else {
// 提取 URL
extractURLsArr := []string{}
for _, u := range allowOrigins {
extractURLsArr = append(extractURLsArr, extractURLForCorsConf(u))
}
// 去重
extractURLsArr = lib.RemoveDuplicates(extractURLsArr)
allowOrigins = extractURLsArr
}

// CORS 配置
// for Preflight Request
// 非法 Origin 浏览器拦截继续的请求
e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: allowOrigins,
AllowCredentials: true, // allow cors with cookies
}))
}
AllowOriginFunc: func(origin string) (bool, error) {
if lib.ContainsStr(config.Instance.TrustedDomains, "*") {
return true, nil // 通配符关闭 origin 检测
}

// 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号)
func extractURLForCorsConf(u string) string {
pu, err := libURL.Parse(u)
if err != nil {
return u
}
allowURLs := []string{}
allowURLs = append(allowURLs, config.Instance.TrustedDomains...) // 导入配置中的可信域名
for _, site := range model.FindAllSitesCooked() { // 导入数据库中的站点 urls
allowURLs = append(allowURLs, site.Urls...)
}

return pu.Scheme + "://" + pu.Host
if len(allowURLs) == 0 {
// 无配置的情况全部放行
// 如程序第一次运行的时候
return true, nil
}

if GetIsAllowOrigin(origin, allowURLs) {
return true, nil
}

return false, nil
},
}))
}
26 changes: 0 additions & 26 deletions http/a_http_test.go

This file was deleted.

79 changes: 46 additions & 33 deletions http/a_site_origin.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ func SiteOriginMiddleware() echo.MiddlewareFunc {
var site *model.Site = nil

siteAll := false
isSuperAdmin := GetIsSuperAdmin(c)

// 请求站点名 == "__ATK_SITE_ALL" 时取消站点隔离
if siteName == lib.ATK_SITE_ALL {
if !CheckIsAdminReq(c) {
if !isSuperAdmin {
return RespError(c, "仅管理员查询允许取消站点隔离")
}

Expand All @@ -47,12 +48,14 @@ func SiteOriginMiddleware() echo.MiddlewareFunc {
siteID = findSite.ID
}

// 检测 Origin 合法性 (防止 CSRF 攻击)
if isOK, resp := CheckOrigin(c, site); !isOK {
return resp
// 检测 Origin 合法性 (防止跨域的 CSRF 攻击)
if !isSuperAdmin { // 管理员忽略 Origin 检测
if isOK, resp := CheckOrigin(c, site); !isOK {
return resp
}
}

// 设置上下文
// 设置 Context Values
c.Set(lib.CTX_KEY_ATK_SITE_ID, siteID)
c.Set(lib.CTX_KEY_ATK_SITE_NAME, siteName)
c.Set(lib.CTX_KEY_ATK_SITE_ALL, siteAll)
Expand All @@ -75,30 +78,21 @@ func UseSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool)
}

// 检测 Origin 合法性
// 防止 CSRF 攻击
// 防止跨域的 CSRF 攻击
// @see https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html
func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) {
// 可信来源 URL
allowSrcURLs := []string{}
allowURLs := []string{}

// 用户配置
allowSrcURLs = append(allowSrcURLs, config.Instance.TrustedDomains...) // 允许配置文件域名
allowURLs = append(allowURLs, config.Instance.TrustedDomains...) // 允许配置文件域名
if allowSite != nil {
allowSrcURLs = append(allowSrcURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名
}
if len(allowSrcURLs) == 0 {
return true, nil // 若用户配置列表中无数据,则取消控制
allowURLs = append(allowURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名
}
if lib.ContainsStr(allowSrcURLs, "*") {
if lib.ContainsStr(allowURLs, "*") {
return true, nil // 列表中出现通配符关闭控制
}

host := c.Request().Host
realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host")
if realHostUnderProxy != "" {
host = realHostUnderProxy
}

// 读取 Origin 数据
// @note Origin 标头在前端 fetch POST 操作中总是携带的,
// 即使配置 Referrer-Policy: no-referrer
Expand All @@ -113,22 +107,41 @@ func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) {
origin = referer
}

pOrigin, err := url.Parse(origin)
if err != nil {
return false, RespError(c, "Origin 不合法")
// 允许同源请求
host := c.Request().Host
realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host")
if realHostUnderProxy != "" {
host = realHostUnderProxy
}
allowURLs = append(allowURLs, c.Scheme()+"://"+host)

// 判断 Origin 是否被允许
if GetIsAllowOrigin(origin, allowURLs) {
return true, nil
}

// 系统配置:默认允许来自相同域名的请求
allowSrcURLs = append(allowSrcURLs, c.Scheme()+"://"+host)
return false, RespError(c, "非法请求,请检查可信域名配置")
}

// 判断 Origin 是否被允许
// origin is 'schema://hostname:port',
// allowURLs is a collection of url strings
func GetIsAllowOrigin(origin string, allowURLs []string) bool {
// Origin 合法性检测
originP, err := url.Parse(origin)
if err != nil || originP.Scheme == "" || originP.Host == "" {
return false
}

allowSrcURLs = lib.RemoveDuplicates(allowSrcURLs) // 去重
for _, a := range allowSrcURLs {
a = strings.TrimSpace(a)
if a == "" {
// 提取 URLs 检测 Origin 是否匹配
for _, u := range allowURLs {
u = strings.TrimSpace(u)
if u == "" {
continue
}
pAllow, err := url.Parse(a)
if err != nil {

urlP, err := url.Parse(u)
if err != nil || urlP.Scheme == "" || urlP.Host == "" {
continue
}

Expand All @@ -138,10 +151,10 @@ func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) {
// Chrome v85+ 默认为:strict-origin-when-cross-origin。
// 前端页面 head 不配置 <meta name="referrer" content="no-referer" />,
// 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口)
if pAllow.Host == pOrigin.Host {
return true, nil
if urlP.Scheme == originP.Scheme && urlP.Host == originP.Host {
return true
}
}

return false, RespError(c, "非法请求,请检查可信域名配置")
return false
}
37 changes: 37 additions & 0 deletions http/a_site_origin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package http

import (
"testing"
)

func Test_GetIsAllowOrigin(t *testing.T) {
tests := []struct {
name string
origin string
allowURLs []string
want bool
}{
{name: "matched allowURLs with slash suffix", origin: "https://qwqaq.com", allowURLs: []string{"https://qwqaq.com/"}, want: true},
{name: "matched allowURLs with path", origin: "https://qwqaq.com", allowURLs: []string{"https://qwqaq.com/test-page/"}, want: true},
{name: "matched allowURLs with port and path", origin: "https://qwqaq.com:12345", allowURLs: []string{"https://qwqaq.com:12345/test-page/"}, want: true},
{name: "matched allowURLs with http schema", origin: "http://qwqaq.com", allowURLs: []string{"http://qwqaq.com"}, want: true},

{name: "not matched, port not same", origin: "https://qwqaq.com:1234", allowURLs: []string{"https://qwqaq.com"}, want: false},
{name: "not matched, protocol not same", origin: "http://qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false},
{name: "not matched, hostname not same", origin: "https://abc.qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false},

{name: "invalid origin 1", origin: "qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false},
{name: "invalid origin 2", origin: "", allowURLs: []string{"https://qwqaq.com"}, want: false},
{name: "invalid origin 3", origin: "null", allowURLs: []string{"https://qwqaq.com"}, want: false},

{name: "matched multi-allowUrls", origin: "https://abc.qwqaq.com", allowURLs: []string{"https://aaaa.com", "https://bbb.com", "https://abc.qwqaq.com/abcd"}, want: true},
{name: "not matched multi-allowUrls", origin: "https://def.qwqaq.com", allowURLs: []string{"https://aaaa.com", "https://bbb.com", "https://abc.qwqaq.com/abcd"}, want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetIsAllowOrigin(tt.origin, tt.allowURLs); got != tt.want {
t.Errorf("GetIsAllowOrigin() = %v, want %v", got, tt.want)
}
})
}
}
3 changes: 3 additions & 0 deletions http/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ func (a *action) Login(c echo.Context) error {
var user model.User
if p.Name == "" {
// 仅 Email 的查询
if !lib.ValidateEmail(p.Email) {
return RespError(c, "请输入正确的邮箱")
}
users := model.FindUsersByEmail(p.Email)
if len(users) == 1 {
// 仅有一个 email 匹配的用户
Expand Down

0 comments on commit 4494e53

Please sign in to comment.