Skip to content
This repository has been archived by the owner on Feb 28, 2023. It is now read-only.

Commit

Permalink
fix(trusted_domains): Extract from full URL with slash suffix & impro…
Browse files Browse the repository at this point in the history
…ve referer interceptor

Signed-off-by: qwqcode <[email protected]>
  • Loading branch information
qwqcode committed Jun 6, 2022
1 parent 1666681 commit bac1271
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 35 deletions.
24 changes: 18 additions & 6 deletions http/a_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
"fmt"
libURL "net/url"

"github.com/ArtalkJS/ArtalkGo/config"
"github.com/ArtalkJS/ArtalkGo/lib"
Expand Down Expand Up @@ -55,20 +56,21 @@ func Run() {

func InitCorsControl(e *echo.Echo) {
allowOrigins := []string{}
pushAllowOrigin := func(u string) {
if !lib.ContainsStr(allowOrigins, u) {
allowOrigins = append(allowOrigins, extractURLForCorsConf(u))
}
}

// 导入配置中的可信域名
for _, v := range config.Instance.TrustedDomains {
if !lib.ContainsStr(allowOrigins, v) {
allowOrigins = append(allowOrigins, v)
}
pushAllowOrigin(v)
}

// 导入数据库中的站点 urls
for _, site := range model.FindAllSitesCooked() {
for _, url := range site.Urls {
if !lib.ContainsStr(allowOrigins, url) {
allowOrigins = append(allowOrigins, url)
}
pushAllowOrigin(url)
}
}

Expand All @@ -81,3 +83,13 @@ func InitCorsControl(e *echo.Echo) {
AllowOrigins: allowOrigins,
}))
}

// 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号)
func extractURLForCorsConf(u string) string {
pu, err := libURL.Parse(u)
if err != nil {
return u
}

return pu.Scheme + "://" + pu.Host
}
26 changes: 26 additions & 0 deletions http/a_http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package http

import (
"testing"
)

func Test_extractURLForCorsConf(t *testing.T) {
tests := []struct {
name string
urls string
want string
}{
{name: "URL with slash suffix", urls: "https://qwqaq.com/", want: "https://qwqaq.com"},
{name: "URL with path", urls: "https://qwqaq.com/test-page/", want: "https://qwqaq.com"},
{name: "URL with port and path", urls: "https://qwqaq.com:12345/test-page/", want: "https://qwqaq.com:12345"},
{name: "URL with http schema", urls: "http://qwqaq.com", want: "http://qwqaq.com"},
{name: "URL with regexp, port and path", urls: "http://*.qwqaq.com:12345/test-page/", want: "http://*.qwqaq.com:12345"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := extractURLForCorsConf(tt.urls); got != tt.want {
t.Errorf("extractURLForCorsConf() = %v, want %v", got, tt.want)
}
})
}
}
57 changes: 28 additions & 29 deletions http/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,58 +108,57 @@ func CheckIsAllowed(c echo.Context, name string, email string, page model.Page,

func CheckReferer(c echo.Context, site model.Site) (bool, error) {
isAdminReq := CheckIsAdminReq(c)
if isAdminReq || site.IsEmpty() {
return true, nil
if isAdminReq {
return true, nil // 管理员直接允许
}

// 可信域名配置
confTrustedDomains := config.Instance.TrustedDomains
// 可信来源 URL
allowReferrers := []string{}
allowReferrers = append(allowReferrers, c.Scheme()+"://"+c.Request().Host) // 默认允许来自相同域名的请求
allowReferrers = append(allowReferrers, config.Instance.TrustedDomains...) // 允许配置文件域名
allowReferrers = append(allowReferrers, site.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名

// 请求 Referer 合法性判断
if strings.TrimSpace(site.Urls) == "" && len(confTrustedDomains) == 0 {
return true, nil // 若 url 字段为空,则取消控制
if len(allowReferrers) == 0 {
return true, nil // 若列表中无数据,则取消控制
}

// 可信域名出现通配符关闭 Referer 控制
if lib.ContainsStr(confTrustedDomains, "*") {
// 列表中出现通配符关闭 Referer 控制
if lib.ContainsStr(allowReferrers, "*") {
return true, nil
}

allowUrls := site.ToCooked().Urls
if len(confTrustedDomains) != 0 {
allowUrls = append(allowUrls, confTrustedDomains...)
}

referer := c.Request().Referer()
if referer == "" {
return true, nil
return false, RespError(c, "需携带 Referer 访问,请检查前端 Referrer-Policy 设置")
}

pr, err := url.Parse(referer)
pReferer, err := url.Parse(referer)
if err != nil {
return true, nil
return false, RespError(c, "Referer 不合法")
}

allow := false
for _, u := range allowUrls {
u = strings.TrimSpace(u)
if u == "" {
for _, a := range allowReferrers {
a = strings.TrimSpace(a)
if a == "" {
continue
}
pu, err := url.Parse(u)
pAllow, err := url.Parse(a)
if err != nil {
continue
}
if pu.Hostname() == pr.Hostname() {
allow = true
break

// 在可信来源列表中匹配 Referer 的 host 部分 (含端口) 则放行
// @see https://web.dev/referrer-best-practices/
// Referrer-Policy 不能设为 no-referer,
// Chrome v85+ 默认为:strict-origin-when-cross-origin。
// 前端页面 head 不配置 <meta name="referrer" content="no-referer" />,
// 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口)
if pAllow.Host == pReferer.Host {
return true, nil
}
}
if !allow {
return false, RespError(c, "非法请求:Referer 不被允许")
}

return true, nil
return false, RespError(c, "不允许的 Referer,请将其加入可信域名")
}

func CheckSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) (bool, error) {
Expand Down

1 comment on commit bac1271

@qwqcode
Copy link
Member Author

@qwqcode qwqcode commented on bac1271 Jun 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related #36

Please sign in to comment.