diff --git a/artalk-go.example.yml b/artalk-go.example.yml index bda6cbd..4486bcb 100644 --- a/artalk-go.example.yml +++ b/artalk-go.example.yml @@ -69,8 +69,7 @@ cache: db: 0 # 可信域名 -trusted_domains: - - "https://artalk.你的域名:23366" +trusted_domains: [] # 例如:["https://artalk.example.com:23366"] # SSL ssl: @@ -270,11 +269,11 @@ frontend: countEl: "#ArtalkCount" # 编辑器实时预览功能 preview: true - # 平铺模式 + # 平铺模式 ["auto", true, false] flatMode: "auto" # 最大嵌套层数 nestMax: 2 - # 嵌套评论排序规则 + # 嵌套评论排序规则 ["DATE_ASC", "DATE_DESC", "VOTE_UP_DESC"] nestSort: DATE_ASC # 头像 gravatar: @@ -298,5 +297,3 @@ frontend: reqTimeout: 15000 # 版本检测 versionCheck: true - # 语言设定 - locale: zh-CN diff --git a/http/a_http.go b/http/a_http.go index bc0be9f..61a1cc3 100644 --- a/http/a_http.go +++ b/http/a_http.go @@ -2,7 +2,6 @@ package http import ( "fmt" - libURL "net/url" "github.com/ArtalkJS/ArtalkGo/config" "github.com/ArtalkJS/ArtalkGo/lib" @@ -55,40 +54,33 @@ func Run() { } func InitCorsControl(e *echo.Echo) { - siteUrls := []string{} - for _, site := range model.FindAllSitesCooked() { - siteUrls = append(siteUrls, site.Urls...) - } - - allowOrigins := []string{} - allowOrigins = append(allowOrigins, config.Instance.TrustedDomains...) // 导入配置中的可信域名 - allowOrigins = append(allowOrigins, siteUrls...) // 导入数据库中的站点 urls - - if lib.ContainsStr(allowOrigins, "*") { - allowOrigins = []string{"*"} // 通配符关闭跨域控制 - } else { - // 提取 URL - extractURLsArr := []string{} - for _, u := range allowOrigins { - extractURLsArr = append(extractURLsArr, extractURLForCorsConf(u)) - } - // 去重 - extractURLsArr = lib.RemoveDuplicates(extractURLsArr) - allowOrigins = extractURLsArr - } - + // CORS 配置 + // for Preflight Request + // 非法 Origin 浏览器拦截继续的请求 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ - AllowOrigins: allowOrigins, AllowCredentials: true, // allow cors with cookies - })) -} + AllowOriginFunc: func(origin string) (bool, error) { + if lib.ContainsStr(config.Instance.TrustedDomains, "*") { + return true, nil // 通配符关闭 origin 检测 + } -// 从完整 URL 中提取出 Scheme + Host 部分 (保留端口号) -func extractURLForCorsConf(u string) string { - pu, err := libURL.Parse(u) - if err != nil { - return u - } + allowURLs := []string{} + allowURLs = append(allowURLs, config.Instance.TrustedDomains...) // 导入配置中的可信域名 + for _, site := range model.FindAllSitesCooked() { // 导入数据库中的站点 urls + allowURLs = append(allowURLs, site.Urls...) + } - return pu.Scheme + "://" + pu.Host + if len(allowURLs) == 0 { + // 无配置的情况全部放行 + // 如程序第一次运行的时候 + return true, nil + } + + if GetIsAllowOrigin(origin, allowURLs) { + return true, nil + } + + return false, nil + }, + })) } diff --git a/http/a_http_test.go b/http/a_http_test.go deleted file mode 100644 index 7709370..0000000 --- a/http/a_http_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package http - -import ( - "testing" -) - -func Test_extractURLForCorsConf(t *testing.T) { - tests := []struct { - name string - urls string - want string - }{ - {name: "URL with slash suffix", urls: "https://qwqaq.com/", want: "https://qwqaq.com"}, - {name: "URL with path", urls: "https://qwqaq.com/test-page/", want: "https://qwqaq.com"}, - {name: "URL with port and path", urls: "https://qwqaq.com:12345/test-page/", want: "https://qwqaq.com:12345"}, - {name: "URL with http schema", urls: "http://qwqaq.com", want: "http://qwqaq.com"}, - {name: "URL with regexp, port and path", urls: "http://*.qwqaq.com:12345/test-page/", want: "http://*.qwqaq.com:12345"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := extractURLForCorsConf(tt.urls); got != tt.want { - t.Errorf("extractURLForCorsConf() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/http/a_site_origin.go b/http/a_site_origin.go index 27188fe..495f80a 100644 --- a/http/a_site_origin.go +++ b/http/a_site_origin.go @@ -20,10 +20,11 @@ func SiteOriginMiddleware() echo.MiddlewareFunc { var site *model.Site = nil siteAll := false + isSuperAdmin := GetIsSuperAdmin(c) // 请求站点名 == "__ATK_SITE_ALL" 时取消站点隔离 if siteName == lib.ATK_SITE_ALL { - if !CheckIsAdminReq(c) { + if !isSuperAdmin { return RespError(c, "仅管理员查询允许取消站点隔离") } @@ -47,12 +48,14 @@ func SiteOriginMiddleware() echo.MiddlewareFunc { siteID = findSite.ID } - // 检测 Origin 合法性 (防止 CSRF 攻击) - if isOK, resp := CheckOrigin(c, site); !isOK { - return resp + // 检测 Origin 合法性 (防止跨域的 CSRF 攻击) + if !isSuperAdmin { // 管理员忽略 Origin 检测 + if isOK, resp := CheckOrigin(c, site); !isOK { + return resp + } } - // 设置上下文 + // 设置 Context Values c.Set(lib.CTX_KEY_ATK_SITE_ID, siteID) c.Set(lib.CTX_KEY_ATK_SITE_NAME, siteName) c.Set(lib.CTX_KEY_ATK_SITE_ALL, siteAll) @@ -75,30 +78,21 @@ func UseSite(c echo.Context, siteName *string, destID *uint, destSiteAll *bool) } // 检测 Origin 合法性 -// 防止 CSRF 攻击 +// 防止跨域的 CSRF 攻击 // @see https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) { // 可信来源 URL - allowSrcURLs := []string{} + allowURLs := []string{} // 用户配置 - allowSrcURLs = append(allowSrcURLs, config.Instance.TrustedDomains...) // 允许配置文件域名 + allowURLs = append(allowURLs, config.Instance.TrustedDomains...) // 允许配置文件域名 if allowSite != nil { - allowSrcURLs = append(allowSrcURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名 - } - if len(allowSrcURLs) == 0 { - return true, nil // 若用户配置列表中无数据,则取消控制 + allowURLs = append(allowURLs, allowSite.ToCooked().Urls...) // 允许数据库站点 URLs 中的域名 } - if lib.ContainsStr(allowSrcURLs, "*") { + if lib.ContainsStr(allowURLs, "*") { return true, nil // 列表中出现通配符关闭控制 } - host := c.Request().Host - realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host") - if realHostUnderProxy != "" { - host = realHostUnderProxy - } - // 读取 Origin 数据 // @note Origin 标头在前端 fetch POST 操作中总是携带的, // 即使配置 Referrer-Policy: no-referrer @@ -113,22 +107,41 @@ func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) { origin = referer } - pOrigin, err := url.Parse(origin) - if err != nil { - return false, RespError(c, "Origin 不合法") + // 允许同源请求 + host := c.Request().Host + realHostUnderProxy := c.Request().Header.Get("X-Forwarded-Host") + if realHostUnderProxy != "" { + host = realHostUnderProxy + } + allowURLs = append(allowURLs, c.Scheme()+"://"+host) + + // 判断 Origin 是否被允许 + if GetIsAllowOrigin(origin, allowURLs) { + return true, nil } - // 系统配置:默认允许来自相同域名的请求 - allowSrcURLs = append(allowSrcURLs, c.Scheme()+"://"+host) + return false, RespError(c, "非法请求,请检查可信域名配置") +} + +// 判断 Origin 是否被允许 +// origin is 'schema://hostname:port', +// allowURLs is a collection of url strings +func GetIsAllowOrigin(origin string, allowURLs []string) bool { + // Origin 合法性检测 + originP, err := url.Parse(origin) + if err != nil || originP.Scheme == "" || originP.Host == "" { + return false + } - allowSrcURLs = lib.RemoveDuplicates(allowSrcURLs) // 去重 - for _, a := range allowSrcURLs { - a = strings.TrimSpace(a) - if a == "" { + // 提取 URLs 检测 Origin 是否匹配 + for _, u := range allowURLs { + u = strings.TrimSpace(u) + if u == "" { continue } - pAllow, err := url.Parse(a) - if err != nil { + + urlP, err := url.Parse(u) + if err != nil || urlP.Scheme == "" || urlP.Host == "" { continue } @@ -138,10 +151,10 @@ func CheckOrigin(c echo.Context, allowSite *model.Site) (bool, error) { // Chrome v85+ 默认为:strict-origin-when-cross-origin。 // 前端页面 head 不配置 , // 浏览器默认都会至少携带 Origin 数据 (不带 path,但包含端口) - if pAllow.Host == pOrigin.Host { - return true, nil + if urlP.Scheme == originP.Scheme && urlP.Host == originP.Host { + return true } } - return false, RespError(c, "非法请求,请检查可信域名配置") + return false } diff --git a/http/a_site_origin_test.go b/http/a_site_origin_test.go new file mode 100644 index 0000000..a16563d --- /dev/null +++ b/http/a_site_origin_test.go @@ -0,0 +1,37 @@ +package http + +import ( + "testing" +) + +func Test_GetIsAllowOrigin(t *testing.T) { + tests := []struct { + name string + origin string + allowURLs []string + want bool + }{ + {name: "matched allowURLs with slash suffix", origin: "https://qwqaq.com", allowURLs: []string{"https://qwqaq.com/"}, want: true}, + {name: "matched allowURLs with path", origin: "https://qwqaq.com", allowURLs: []string{"https://qwqaq.com/test-page/"}, want: true}, + {name: "matched allowURLs with port and path", origin: "https://qwqaq.com:12345", allowURLs: []string{"https://qwqaq.com:12345/test-page/"}, want: true}, + {name: "matched allowURLs with http schema", origin: "http://qwqaq.com", allowURLs: []string{"http://qwqaq.com"}, want: true}, + + {name: "not matched, port not same", origin: "https://qwqaq.com:1234", allowURLs: []string{"https://qwqaq.com"}, want: false}, + {name: "not matched, protocol not same", origin: "http://qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false}, + {name: "not matched, hostname not same", origin: "https://abc.qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false}, + + {name: "invalid origin 1", origin: "qwqaq.com", allowURLs: []string{"https://qwqaq.com"}, want: false}, + {name: "invalid origin 2", origin: "", allowURLs: []string{"https://qwqaq.com"}, want: false}, + {name: "invalid origin 3", origin: "null", allowURLs: []string{"https://qwqaq.com"}, want: false}, + + {name: "matched multi-allowUrls", origin: "https://abc.qwqaq.com", allowURLs: []string{"https://aaaa.com", "https://bbb.com", "https://abc.qwqaq.com/abcd"}, want: true}, + {name: "not matched multi-allowUrls", origin: "https://def.qwqaq.com", allowURLs: []string{"https://aaaa.com", "https://bbb.com", "https://abc.qwqaq.com/abcd"}, want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := GetIsAllowOrigin(tt.origin, tt.allowURLs); got != tt.want { + t.Errorf("GetIsAllowOrigin() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/http/user_login.go b/http/user_login.go index 54a3ab4..41eb56a 100644 --- a/http/user_login.go +++ b/http/user_login.go @@ -30,6 +30,9 @@ func (a *action) Login(c echo.Context) error { var user model.User if p.Name == "" { // 仅 Email 的查询 + if !lib.ValidateEmail(p.Email) { + return RespError(c, "请输入正确的邮箱") + } users := model.FindUsersByEmail(p.Email) if len(users) == 1 { // 仅有一个 email 匹配的用户