From 95ac589f8c4450381cd804175a7494f52e9d34d0 Mon Sep 17 00:00:00 2001 From: kkHAIKE Date: Mon, 19 Sep 2022 17:10:33 +0800 Subject: [PATCH] auto mark server-side request --- contextcheck.go | 61 ++++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/contextcheck.go b/contextcheck.go index 573475b..63a5d30 100644 --- a/contextcheck.go +++ b/contextcheck.go @@ -75,6 +75,8 @@ type resInfo struct { // reuse for doc ReqCtx bool Skip bool + + EntryType entryType } type ctxFact map[string]resInfo @@ -234,10 +236,18 @@ func (r *runner) noImportedContextAndHttp(f *ssa.Function) (ret bool) { return true } -func (r *runner) checkIsEntry(f *ssa.Function) entryType { +func (r *runner) checkIsEntry(f *ssa.Function) (ret entryType) { // if r.noImportedContextAndHttp(f) { // return EntryNormal // } + key := "entry:" + f.RelString(nil) + res, ok := r.getValue(key, f) + if ok { + return res.EntryType + } + defer func() { + r.currentFact[key] = resInfo{EntryType: ret} + }() ctxIn, ctxOut := r.checkIsCtx(f) if ctxOut { @@ -264,21 +274,14 @@ func (r *runner) checkIsEntry(f *ssa.Function) entryType { } func (r *runner) docFlag(f *ssa.Function) (reqctx, skip bool) { - key := "doc:" + f.RelString(nil) - res, ok := r.getValue(key, f) - if ok { - return res.ReqCtx, res.Skip - } - for _, v := range r.getDocFromFunc(f) { if len(nolintRe.FindString(v.Text)) > 0 && strings.Contains(v.Text, "contextcheck") { - res.Skip = true + skip = true } else if strings.HasPrefix(v.Text, "// @contextcheck(req_has_ctx)") { - res.ReqCtx = true + reqctx = true } } - r.currentFact[key] = res - return res.ReqCtx, res.Skip + return } var nolintRe = regexp.MustCompile(`^//\s?nolint:`) @@ -333,22 +336,30 @@ func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) { } func (r *runner) checkIsHttpHandler(f *ssa.Function, reqctx bool) bool { - if reqctx { - tuple := f.Signature.Params() - for i := 0; i < tuple.Len(); i++ { - if r.isHttpReqType(tuple.At(i).Type()) { - return true - } + var hasReq bool + tuple := f.Signature.Params() + for i := 0; i < tuple.Len(); i++ { + if r.isHttpReqType(tuple.At(i).Type()) { + hasReq = true + break } } - - // must has no result - if f.Signature.Results().Len() > 0 { + if !hasReq { return false } + if reqctx { + return true + } + + // check if use r.Context() + if f.Blocks != nil && len(r.getHttpReqCtx(f, true)) > 0 { + return true + } // must be `func f(w http.ResponseWriter, r *http.Request) {}` - tuple := f.Signature.Params() + if f.Signature.Results().Len() > 0 { + return false + } if tuple.Len() != 2 { return false } @@ -420,7 +431,7 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ } if isHttpHandler { - for _, v := range r.getHttpReqCtx(f) { + for _, v := range r.getHttpReqCtx(f, false) { checkRefs(v, false) } } else { @@ -456,7 +467,7 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ return } -func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) { +func (r *runner) getHttpReqCtx(f *ssa.Function, least1 bool) (rets []ssa.Value) { checkedRefMap := make(map[ssa.Value]bool) var checkRefs func(val ssa.Value, fromAddr bool) @@ -498,6 +509,9 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) { if f.Signature.Recv() != nil { // collect the return of r.Context rets = append(rets, i.Value()) + if least1 { + return + } } case *ssa.Store: if !fromAddr { @@ -516,7 +530,6 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) { for _, param := range f.Params { if r.isHttpReqType(param.Type()) { checkRefs(param, false) - break } }