Skip to content

Commit

Permalink
internal/middleware: move stats to its own package
Browse files Browse the repository at this point in the history
Make a package internal/middleware/stats for middleware.Stats and
middleware.ElapsedStat. This is part of removing the dependency from
internal/frontend on internal/middleware.

For golang/go#61399

Change-Id: I44afbfc9b9e28e1caabab8fe700376ec026c863d
Reviewed-on: https://go-review.googlesource.com/c/pkgsite/+/514521
TryBot-Result: Gopher Robot <[email protected]>
Run-TryBot: Michael Matloob <[email protected]>
Reviewed-by: Jamal Carvalho <[email protected]>
kokoro-CI: kokoro <[email protected]>
  • Loading branch information
matloob committed Aug 4, 2023
1 parent 051a825 commit 9e6bdc9
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 63 deletions.
4 changes: 2 additions & 2 deletions internal/frontend/details.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package frontend
import (
"context"
"errors"
mstats "golang.org/x/pkgsite/internal/middleware/stats"
"net/http"
"strings"

Expand All @@ -16,7 +17,6 @@ import (
"go.opencensus.io/tag"
"golang.org/x/mod/semver"
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/stdlib"
)

Expand All @@ -25,7 +25,7 @@ import (
// stdlib module pages are handled at "/std", and requests to "/mod/std" will
// be redirected to that path.
func (s *Server) serveDetails(w http.ResponseWriter, r *http.Request, ds internal.DataSource) (err error) {
defer middleware.ElapsedStat(r.Context(), "serveDetails")()
defer mstats.Elapsed(r.Context(), "serveDetails")()

ctx := r.Context()
if r.Method != http.MethodGet && r.Method != http.MethodHead {
Expand Down
4 changes: 2 additions & 2 deletions internal/frontend/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import (
"golang.org/x/pkgsite/internal/godoc"
"golang.org/x/pkgsite/internal/godoc/dochtml"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/stdlib"
)

func renderDocParts(ctx context.Context, u *internal.Unit, docPkg *godoc.Package,
nameToVersion map[string]string, bc internal.BuildContext) (_ *dochtml.Parts, err error) {
defer derrors.Wrap(&err, "renderDocParts")
defer middleware.ElapsedStat(ctx, "renderDocParts")()
defer stats.Elapsed(ctx, "renderDocParts")()

modInfo := &godoc.ModuleInfo{
ModulePath: u.ModulePath,
Expand Down
4 changes: 2 additions & 2 deletions internal/frontend/latest_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
)

// GetLatestInfo returns various pieces of information about the latest
Expand All @@ -21,7 +21,7 @@ import (
// It returns empty strings on error.
// It is intended to be used as an argument to middleware.LatestVersions.
func (s *Server) GetLatestInfo(ctx context.Context, unitPath, modulePath string, latestUnitMeta *internal.UnitMeta) internal.LatestInfo {
defer middleware.ElapsedStat(ctx, "GetLatestInfo")()
defer stats.Elapsed(ctx, "GetLatestInfo")()

// It is okay to use a different DataSource (DB connection) than the rest of the
// request, because this makes self-contained calls on the DB.
Expand Down
9 changes: 5 additions & 4 deletions internal/frontend/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"golang.org/x/pkgsite/internal/godoc/dochtml"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/version"
"golang.org/x/text/message"
)
Expand Down Expand Up @@ -107,7 +108,7 @@ type File struct {

func fetchMainDetails(ctx context.Context, ds internal.DataSource, um *internal.UnitMeta,
requestedVersion string, expandReadme bool, bc internal.BuildContext) (_ *MainDetails, err error) {
defer middleware.ElapsedStat(ctx, "fetchMainDetails")()
defer stats.Elapsed(ctx, "fetchMainDetails")()

unit, err := ds.GetUnit(ctx, um, internal.WithMain, bc)
if err != nil {
Expand Down Expand Up @@ -146,7 +147,7 @@ func fetchMainDetails(ctx context.Context, ds internal.DataSource, um *internal.
goos = doc.GOOS
goarch = doc.GOARCH
buildContexts = unit.BuildContexts
end := middleware.ElapsedStat(ctx, "DecodePackage")
end := stats.Elapsed(ctx, "DecodePackage")
docPkg, err := godoc.DecodePackage(doc.Source)
end()
if err != nil {
Expand All @@ -167,7 +168,7 @@ func fetchMainDetails(ctx context.Context, ds internal.DataSource, um *internal.
for _, l := range docParts.Links {
docLinks = append(docLinks, link{Href: l.Href, Body: l.Text})
}
end = middleware.ElapsedStat(ctx, "sourceFiles")
end = stats.Elapsed(ctx, "sourceFiles")
files = sourceFiles(unit, docPkg)
end()
}
Expand Down Expand Up @@ -253,7 +254,7 @@ func cleanDocumentation(docs []*internal.Documentation) []*internal.Documentatio
// into an outline.
func readmeContent(ctx context.Context, u *internal.Unit) (_ *Readme, err error) {
defer derrors.Wrap(&err, "readmeContent(%q, %q, %q)", u.Path, u.ModulePath, u.Version)
defer middleware.ElapsedStat(ctx, "readmeContent")()
defer stats.Elapsed(ctx, "readmeContent")()
if !u.IsRedistributable {
return &Readme{}, nil
}
Expand Down
10 changes: 5 additions & 5 deletions internal/frontend/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
"golang.org/x/pkgsite/internal/licenses"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/memory"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/queue"
"golang.org/x/pkgsite/internal/static"
"golang.org/x/pkgsite/internal/version"
Expand Down Expand Up @@ -198,9 +198,9 @@ func (s *Server) Install(handle func(string, http.Handler), cacher Cacher, authV
handle("/", detailHandler)
if s.serveStats {
handle("/detail-stats/",
middleware.Stats()(http.StripPrefix("/detail-stats", s.errorHandler(s.serveDetails))))
stats.Stats()(http.StripPrefix("/detail-stats", s.errorHandler(s.serveDetails))))
handle("/search-stats/",
middleware.Stats()(http.StripPrefix("/search-stats", s.errorHandler(s.serveSearch))))
stats.Stats()(http.StripPrefix("/search-stats", s.errorHandler(s.serveSearch))))
}
handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
Expand Down Expand Up @@ -672,7 +672,7 @@ func (s *Server) renderErrorPage(ctx context.Context, status int, templateName s

// servePage is used to execute all templates for a *Server.
func (s *Server) servePage(ctx context.Context, w http.ResponseWriter, templateName string, page any) {
defer middleware.ElapsedStat(ctx, "servePage")()
defer stats.Elapsed(ctx, "servePage")()

buf, err := s.renderPage(ctx, templateName, page)
if err != nil {
Expand All @@ -688,7 +688,7 @@ func (s *Server) servePage(ctx context.Context, w http.ResponseWriter, templateN

// renderPage executes the given templateName with page.
func (s *Server) renderPage(ctx context.Context, templateName string, page any) ([]byte, error) {
defer middleware.ElapsedStat(ctx, "renderPage")()
defer stats.Elapsed(ctx, "renderPage")()

tmpl, err := s.findTemplate(templateName)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions internal/frontend/unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"golang.org/x/pkgsite/internal/cookie"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/log"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/stdlib"
"golang.org/x/pkgsite/internal/version"
"golang.org/x/pkgsite/internal/vuln"
Expand Down Expand Up @@ -106,7 +106,7 @@ type UnitPage struct {
func (s *Server) serveUnitPage(ctx context.Context, w http.ResponseWriter, r *http.Request,
ds internal.DataSource, info *urlPathInfo) (err error) {
defer derrors.Wrap(&err, "serveUnitPage(ctx, w, r, ds, %v)", info)
defer middleware.ElapsedStat(ctx, "serveUnitPage")()
defer stats.Elapsed(ctx, "serveUnitPage")()

tab := r.FormValue("tab")
if tab == "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware
package stats

import (
"context"
Expand All @@ -18,7 +18,7 @@ type statsKey struct{}

// Stats returns a Middleware that, instead of serving the page,
// serves statistics about the page.
func Stats() Middleware {
func Stats() func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sw := newStatsResponseWriter()
Expand All @@ -29,9 +29,9 @@ func Stats() Middleware {
}
}

// SetStat sets a stat named key in the current context. If key already has a
// set sets a stat named key in the current context. If key already has a
// value, the old and new value are both stored in a slice.
func SetStat(ctx context.Context, key string, value any) {
func set(ctx context.Context, key string, value any) {
x := ctx.Value(statsKey{})
if x == nil {
return
Expand All @@ -47,17 +47,17 @@ func SetStat(ctx context.Context, key string, value any) {
}
}

// ElapsedStat records as a stat the elapsed time for a
// Elapsed records as a stat the elapsed time for a
// function execution. Invoke like so:
//
// defer ElapsedStat(ctx, "FunctionName")()
// defer Elapsed(ctx, "FunctionName")()
//
// The resulting stat will be called "FunctionName ms" and will
// be the wall-clock execution time of the function in milliseconds.
func ElapsedStat(ctx context.Context, name string) func() {
func Elapsed(ctx context.Context, name string) func() {
start := time.Now()
return func() {
SetStat(ctx, name+" ms", time.Since(start).Milliseconds())
set(ctx, name+" ms", time.Since(start).Milliseconds())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package middleware
package stats

import (
"encoding/json"
Expand All @@ -23,11 +23,11 @@ func TestStats(t *testing.T) {
ts := httptest.NewServer(Stats()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
w.WriteHeader(code)
SetStat(ctx, "a", 1)
set(ctx, "a", 1)
w.Write(data[:10])
SetStat(ctx, "b", 2)
set(ctx, "b", 2)
time.Sleep(500 * time.Millisecond)
SetStat(ctx, "a", 3)
set(ctx, "a", 3)
w.Write(data[10:])
})))
defer ts.Close()
Expand Down
8 changes: 4 additions & 4 deletions internal/postgres/details.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
)

// GetNestedModules returns the latest major version of all nested modules
// given a modulePath path prefix with or without major version.
func (db *DB) GetNestedModules(ctx context.Context, modulePath string) (_ []*internal.ModuleInfo, err error) {
defer derrors.WrapStack(&err, "GetNestedModules(ctx, %v)", modulePath)
defer middleware.ElapsedStat(ctx, "GetNestedModules")()
defer stats.Elapsed(ctx, "GetNestedModules")()

query := `
SELECT DISTINCT ON (series_path)
Expand Down Expand Up @@ -78,7 +78,7 @@ func (db *DB) GetNestedModules(ctx context.Context, modulePath string) (_ []*int
// Instead of supporting pagination, this query runs with a limit.
func (db *DB) GetImportedBy(ctx context.Context, pkgPath, modulePath string, limit int) (paths []string, err error) {
defer derrors.WrapStack(&err, "GetImportedBy(ctx, %q, %q)", pkgPath, modulePath)
defer middleware.ElapsedStat(ctx, "GetImportedBy")()
defer stats.Elapsed(ctx, "GetImportedBy")()

if pkgPath == "" {
return nil, fmt.Errorf("pkgPath cannot be empty: %w", derrors.InvalidArgument)
Expand All @@ -102,7 +102,7 @@ func (db *DB) GetImportedBy(ctx context.Context, pkgPath, modulePath string, lim
// GetImportedByCount returns the number of packages that import pkgPath.
func (db *DB) GetImportedByCount(ctx context.Context, pkgPath, modulePath string) (_ int, err error) {
defer derrors.WrapStack(&err, "GetImportedByCount(ctx, %q, %q)", pkgPath, modulePath)
defer middleware.ElapsedStat(ctx, "GetImportedByCount")()
defer stats.Elapsed(ctx, "GetImportedByCount")()

if pkgPath == "" {
return 0, fmt.Errorf("pkgPath cannot be empty: %w", derrors.InvalidArgument)
Expand Down
4 changes: 2 additions & 2 deletions internal/postgres/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ import (
"github.com/lib/pq"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/licenses"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/stdlib"
)

func (db *DB) getLicenses(ctx context.Context, fullPath, modulePath string, unitID int) (_ []*licenses.License, err error) {
defer derrors.WrapStack(&err, "getLicenses(ctx, %d)", unitID)
defer middleware.ElapsedStat(ctx, "getLicenses")()
defer stats.Elapsed(ctx, "getLicenses")()

query := `
SELECT
Expand Down
4 changes: 2 additions & 2 deletions internal/postgres/package_symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
)

// getPackageSymbols returns all of the symbols for a given package path and module path.
func getPackageSymbols(ctx context.Context, ddb *database.DB, packagePath, modulePath string,
) (_ *internal.SymbolHistory, err error) {
defer derrors.Wrap(&err, "getPackageSymbols(ctx, ddb, %q, %q)", packagePath, modulePath)
defer middleware.ElapsedStat(ctx, "getPackageSymbols")()
defer stats.Elapsed(ctx, "getPackageSymbols")()

query := packageSymbolQueryJoin(
squirrel.Select(
Expand Down
8 changes: 4 additions & 4 deletions internal/postgres/symbol_history.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"golang.org/x/pkgsite/internal"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/symbol"
)

Expand All @@ -22,7 +22,7 @@ import (
func (db *DB) GetSymbolHistory(ctx context.Context, packagePath, modulePath string,
) (_ *internal.SymbolHistory, err error) {
defer derrors.Wrap(&err, "GetSymbolHistory(ctx, %q, %q)", packagePath, modulePath)
defer middleware.ElapsedStat(ctx, "GetSymbolHistory")()
defer stats.Elapsed(ctx, "GetSymbolHistory")()

return GetSymbolHistoryFromTable(ctx, db.db, packagePath, modulePath)
}
Expand Down Expand Up @@ -71,7 +71,7 @@ func GetSymbolHistoryFromTable(ctx context.Context, ddb *database.DB,
func GetSymbolHistoryWithPackageSymbols(ctx context.Context, ddb *database.DB,
packagePath, modulePath string) (_ *internal.SymbolHistory, err error) {
defer derrors.WrapStack(&err, "GetSymbolHistoryWithPackageSymbols(ctx, ddb, %q, %q)", packagePath, modulePath)
defer middleware.ElapsedStat(ctx, "GetSymbolHistoryWithPackageSymbols")()
defer stats.Elapsed(ctx, "GetSymbolHistoryWithPackageSymbols")()
sh, err := getPackageSymbols(ctx, ddb, packagePath, modulePath)
if err != nil {
return nil, err
Expand All @@ -86,7 +86,7 @@ func GetSymbolHistoryWithPackageSymbols(ctx context.Context, ddb *database.DB,
func GetSymbolHistoryForBuildContext(ctx context.Context, ddb *database.DB, pathID int, modulePath string,
bc internal.BuildContext) (_ map[string]string, err error) {
defer derrors.WrapStack(&err, "GetSymbolHistoryForBuildContext(ctx, ddb, %d, %q)", pathID, modulePath)
defer middleware.ElapsedStat(ctx, "GetSymbolHistoryForBuildContext")()
defer stats.Elapsed(ctx, "GetSymbolHistoryForBuildContext")()

if bc == internal.BuildContextAll {
bc = internal.BuildContextLinux
Expand Down
12 changes: 6 additions & 6 deletions internal/postgres/symbolsearch.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ import (
"github.com/lib/pq"
"golang.org/x/pkgsite/internal/database"
"golang.org/x/pkgsite/internal/derrors"
"golang.org/x/pkgsite/internal/middleware"
"golang.org/x/pkgsite/internal/middleware/stats"
"golang.org/x/pkgsite/internal/postgres/search"
"golang.org/x/sync/errgroup"
)

func upsertSymbolSearchDocuments(ctx context.Context, tx *database.DB,
modulePath, v string) (err error) {
defer derrors.Wrap(&err, "upsertSymbolSearchDocuments(ctx, ddb, %q, %q)", modulePath, v)
defer middleware.ElapsedStat(ctx, "upsertSymbolSearchDocuments")()
defer stats.Elapsed(ctx, "upsertSymbolSearchDocuments")()

// If a user is looking for the symbol "DB.Begin", from package
// database/sql, we want them to be able to find this by searching for
Expand Down Expand Up @@ -97,7 +97,7 @@ func upsertSymbolSearchDocuments(ctx context.Context, tx *database.DB,
// TODO(https://golang.org/issue/44142): factor out common code between
// symbolSearch and deepSearch.
func (db *DB) symbolSearch(ctx context.Context, q string, limit int, opts SearchOptions) searchResponse {
defer middleware.ElapsedStat(ctx, "symbolSearch")()
defer stats.Elapsed(ctx, "symbolSearch")()

var (
results []*SearchResult
Expand Down Expand Up @@ -156,7 +156,7 @@ func runSymbolSearchMultiWord(ctx context.Context, ddb *database.DB, q string, l
symbolFilter string) (_ []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearchMultiWord(ctx, ddb, query, %q, %d, %q)",
q, limit, symbolFilter)
defer middleware.ElapsedStat(ctx, "runSymbolSearchMultiWord")()
defer stats.Elapsed(ctx, "runSymbolSearchMultiWord")()

symbolToPathTokens := multiwordSearchCombinations(q, symbolFilter)
if len(symbolToPathTokens) == 0 {
Expand Down Expand Up @@ -259,7 +259,7 @@ func multiwordSearchCombinations(q, symbolFilter string) map[string]string {
// when using an OR in the WHERE clause.
func runSymbolSearchOneDot(ctx context.Context, ddb *database.DB, q string, limit int) (_ []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearchOneDot(ctx, ddb, %q, %d)", q, limit)
defer middleware.ElapsedStat(ctx, "runSymbolSearchOneDot")()
defer stats.Elapsed(ctx, "runSymbolSearchOneDot")()

group, searchCtx := errgroup.WithContext(ctx)
resultsArray := make([][]*SearchResult, 2)
Expand Down Expand Up @@ -318,7 +318,7 @@ func splitPackageAndSymbolNames(q string) (pkgName string, symbolName string, er
func runSymbolSearch(ctx context.Context, ddb *database.DB,
st search.SearchType, q string, limit int, args ...any) (results []*SearchResult, err error) {
defer derrors.Wrap(&err, "runSymbolSearch(ctx, ddb, %q, %q, %d, %v)", st, q, limit, args)
defer middleware.ElapsedStat(ctx, fmt.Sprintf("%s-runSymbolSearch", st))()
defer stats.Elapsed(ctx, fmt.Sprintf("%s-runSymbolSearch", st))()

collect := func(rows *sql.Rows) error {
var r SearchResult
Expand Down
Loading

0 comments on commit 9e6bdc9

Please sign in to comment.