From bac1271a96881e02d36f5ab8f087e771e694a877 Mon Sep 17 00:00:00 2001 From: qwqcode Date: Mon, 6 Jun 2022 21:48:20 +0800 Subject: [PATCH] fix(trusted_domains): Extract from full URL with slash suffix & improve referer interceptor Signed-off-by: qwqcode --- http/a_http.go | 24 ++++++++++++++----- http/a_http_test.go | 26 +++++++++++++++++++++ http/utils.go | 57 ++++++++++++++++++++++----------------------- 3 files changed, 72 insertions(+), 35 deletions(-) create mode 100644 http/a_http_test.go diff --git a/http/a_http.go b/http/a_http.go index d255231..b4d593e 100644 --- a/http/a_http.go +++ b/http/a_http.go @@ -2,6 +2,7 @@ package http import ( "fmt" + libURL "net/url" "github.com/ArtalkJS/ArtalkGo/config" "github.com/ArtalkJS/ArtalkGo/lib" @@ -55,20 +56,21 @@ func Run() { func InitCorsControl(e *echo.Echo) { allowOrigins := []string{} + pushAllowOrigin := func(u string) { + if !lib.ContainsStr(allowOrigins, u) { + allowOrigins = append(allowOrigins, extractURLForCorsConf(u)) + } + } // 导入配置中的可信域名 for _, v := range config.Instance.TrustedDomains { - if !lib.ContainsStr(allowOrigins, v) { - allowOrigins = append(allowOrigins, v) - } + pushAllowOrigin(v) } // 导入数据库中的站点 urls for _, site := range model.FindAllSitesCooked() { for _, url := range site.Urls { - if !lib.ContainsStr(allowOrigins, url) { - allowOrigins = append(allowOrigins, url) - } + pushAllowOrigin(url) } } @@ -81,3 +83,13 @@ func InitCorsControl(e *echo.Echo) { AllowOrigins: allowOrigins, })) } + +// 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号) +func extractURLForCorsConf(u string) string { + pu, err := libURL.Parse(u) + if err != nil { + return u + } + + return pu.Scheme + "://" + pu.Host +} diff --git a/http/a_http_test.go b/http/a_http_test.go new file mode 100644 index 0000000..7709370 --- /dev/null +++ b/http/a_http_test.go @@ -0,0 +1,26 @@ +package http + +import ( + "testing" +) + +func Test_extractURLForCorsConf(t *testing.T) { + tests := []struct { + name string + urls string + want string + }{ + {name: "URL with slash suffix", urls: "https://qwqaq.com/", want: "https://qwqaq.com"}, + {name: "URL with path", urls: "https://qwqaq.com/test-page/", want: "https://qwqaq.com"}, + {name: "URL with port and path", urls: "https://qwqaq.com:12345/test-page/", want: "https://qwqaq.com:12345"}, + {name: "URL with http schema", urls: "http://qwqaq.com", want: "http://qwqaq.com"}, + {name: "URL with regexp, port and path", urls: "http://*.qwqaq.com:12345/test-page/", want: "http://*.qwqaq.com:12345"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := extractURLForCorsConf(tt.urls); got != tt.want { + t.Errorf("extractURLForCorsConf() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/http/utils.go b/http/utils.go index fbe246a..a5bb6ab 100644 --- a/http/utils.go +++ b/http/utils.go @@ -108,58 +108,57 @@ func CheckIsAllowed(c echo.Context, name string, email string, page model.Page, func CheckReferer(c echo.Context, site model.Site) (bool, error) { isAdminReq := CheckIsAdminReq(c) - if isAdminReq || site.IsEmpty() { - return true, nil + if isAdminReq { + return true, nil // 管理员直接允许 } - // 可信域名配置 - confTrustedDomains := config.Instance.TrustedDomains + // 可信来源 URL + allowReferrers := []string{} + allowReferrers = append(allowReferrers, c.Scheme()+"://"+c.Request().Host) // 默认允许来自相同域名的请求 + allowReferrers = append(allowReferrers, config.Instance.TrustedDomains...) // 允许配置文件域名 + allowReferrers = append(allowReferrers, site.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名 - // 请求 Referer 合法性判断 - if strings.TrimSpace(site.Urls) == "" && len(confTrustedDomains) == 0 { - return true, nil // 若 url 字段为空,则取消控制 + if len(allowReferrers) == 0 { + return true, nil // 若列表中无数据,则取消控制 } - // 可信域名出现通配符关闭 Referer 控制 - if lib.ContainsStr(confTrustedDomains, "*") { + // 列表中出现通配符关闭 Referer 控制 + if lib.ContainsStr(allowReferrers, "*") { return true, nil } - allowUrls := site.ToCooked().Urls - if len(confTrustedDomains) != 0 { - allowUrls = append(allowUrls, confTrustedDomains...) - } - referer := c.Request().Referer() if referer == "" { - return true, nil + return false, RespError(c, "需携带 Referer 访问,请检查前端 Referrer-Policy 设置") } - pr, err := url.Parse(referer) + pReferer, err := url.Parse(referer) if err != nil { - return true, nil + return false, RespError(c, "Referer 不合法") } - allow := false - for _, u := range allowUrls { - u = strings.TrimSpace(u) - if u == "" { + for _, a := range allowReferrers { + a = strings.TrimSpace(a) + if a == "" { continue } - pu, err := url.Parse(u) + pAllow, err := url.Parse(a) if err != nil { continue } - if pu.Hostname() == pr.Hostname() { - allow = true - break + + // 在可信来源列表中匹配 Referer 的 host 部分 (含端口) 则放行 + // @see https://web.dev/referrer-best-practices/ + // Referrer-Policy 不能设为 no-referer, + // Chrome v85+ 默认为:strict-origin-when-cross-origin。 + // 前端页面 head 不配置 , + // 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口) + if pAllow.Host == pReferer.Host { + return true, nil } } - if !allow { - return false, RespError(c, "非法请求:Referer 不被允许") - } - return true, nil + return false, RespError(c, "不允许的 Referer,请将其加入可信域名") } func CheckSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) (bool, error) {