diff --git a/esi/choose.go b/esi/choose.go
index 572fdb3..6b1005f 100644
--- a/esi/choose.go
+++ b/esi/choose.go
@@ -44,8 +44,7 @@ func (c *chooseTag) Process(b []byte, req *http.Request) ([]byte, int) {
for _, v := range tagIdxs {
if validateTest(v[1], req) {
res = Parse(v[2], req)
-
- break
+ return res, c.length
}
}
diff --git a/esi/include.go b/esi/include.go
index 2893e39..309598d 100644
--- a/esi/include.go
+++ b/esi/include.go
@@ -15,6 +15,18 @@ var (
altAttribute = regexp.MustCompile(`alt="?(.+?)"?( |/>)`)
)
+// safe to pass to any origin
+var headersSafe = []string{
+ "Accept",
+ "Accept-Language",
+}
+
+// safe to pass only to same-origin (same scheme, same host, same port)
+var headersUnsafe = []string{
+ "Cookie",
+ "Authorization",
+}
+
type includeTag struct {
*baseTag
src string
@@ -42,6 +54,15 @@ func sanitizeURL(u string, reqUrl *url.URL) string {
return reqUrl.ResolveReference(parsed).String()
}
+func addHeaders(headers []string, req *http.Request, rq *http.Request) {
+ for _, h := range headers {
+ v := req.Header.Get(h)
+ if v != "" {
+ rq.Header.Add(h, v)
+ }
+ }
+}
+
// Input (e.g. include src="https://domain.com/esi-include" alt="https://domain.com/alt-esi-include" />)
// With or without the alt
// With or without a space separator before the closing
@@ -59,11 +80,20 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) {
}
rq, _ := http.NewRequest(http.MethodGet, sanitizeURL(i.src, req.URL), nil)
+ addHeaders(headersSafe, req, rq)
+ if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host {
+ addHeaders(headersUnsafe, req, rq)
+ }
+
client := &http.Client{}
response, err := client.Do(rq)
- if err != nil || response.StatusCode >= 400 {
+ if (err != nil || response.StatusCode >= 400) && i.alt != "" {
rq, _ = http.NewRequest(http.MethodGet, sanitizeURL(i.alt, req.URL), nil)
+ addHeaders(headersSafe, req, rq)
+ if rq.URL.Scheme == req.URL.Scheme && rq.URL.Host == req.URL.Host {
+ addHeaders(headersUnsafe, req, rq)
+ }
response, err = client.Do(rq)
if err != nil || response.StatusCode >= 400 {
@@ -71,6 +101,10 @@ func (i *includeTag) Process(b []byte, req *http.Request) ([]byte, int) {
}
}
+ if response == nil {
+ return nil, i.length
+ }
+
defer response.Body.Close()
x, _ := io.ReadAll(response.Body)
b = Parse(x, req)
diff --git a/fixtures/full.html b/fixtures/full.html
index 1773ee8..7a040c6 100644
--- a/fixtures/full.html
+++ b/fixtures/full.html
@@ -11,5 +11,18 @@
-->
+
+
+
+
+
+
+
+
+
+
+
+
+