Skip to content

Commit

Permalink
Added context.Context as an argument or a field of struct to the exis…
Browse files Browse the repository at this point in the history
…ting functions or structs to enable Contextual Logging
  • Loading branch information
bells17 committed Sep 16, 2023
1 parent f942a75 commit 1b7a924
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 49 deletions.
33 changes: 16 additions & 17 deletions connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ func SetMaxGRPCLogLength(characterCount int) {
//
// For other connections, the default behavior from gRPC is used and
// loss of connection is not detected reliably.
func Connect(address string, metricsManager metrics.CSIMetricsManager, options ...Option) (*grpc.ClientConn, error) {
func Connect(ctx context.Context, address string, metricsManager metrics.CSIMetricsManager, options ...Option) (*grpc.ClientConn, error) {
// Prepend default options
options = append([]Option{WithTimeout(time.Second * 30)}, options...)
if metricsManager != nil {
options = append([]Option{WithMetrics(metricsManager)}, options...)
}
return connect(address, options)
return connect(ctx, address, options)
}

// ConnectWithoutMetrics behaves exactly like Connect except no metrics are recorded.
// This function is deprecated, prefer using Connect with `nil` as the metricsManager.
func ConnectWithoutMetrics(address string, options ...Option) (*grpc.ClientConn, error) {
func ConnectWithoutMetrics(ctx context.Context, address string, options ...Option) (*grpc.ClientConn, error) {
// Prepend default options
options = append([]Option{WithTimeout(time.Second * 30)}, options...)
return connect(address, options)
return connect(ctx, address, options)
}

// Option is the type of all optional parameters for Connect.
Expand All @@ -104,13 +104,13 @@ func OnConnectionLoss(reconnect func() bool) Option {

// ExitOnConnectionLoss returns callback for OnConnectionLoss() that writes
// an error to /dev/termination-log and exits.
func ExitOnConnectionLoss() func() bool {
return func() bool {
func ExitOnConnectionLoss() func(context.Context) bool {
return func(ctx context.Context) bool {
terminationMsg := "Lost connection to CSI driver, exiting"
if err := os.WriteFile(terminationLogPath, []byte(terminationMsg), 0644); err != nil {
klog.Background().Error(err, "Failed to write a message to the termination logfile", "terminationLogPath", terminationLogPath)
klog.FromContext(ctx).Error(err, "Failed to write a message to the termination logfile", "terminationLogPath", terminationLogPath)
}
klog.Background().Error(nil, terminationMsg)
klog.FromContext(ctx).Error(nil, terminationMsg)
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
// Not reached.
return false
Expand Down Expand Up @@ -147,6 +147,7 @@ type options struct {

// connect is the internal implementation of Connect. It has more options to enable testing.
func connect(
ctx context.Context,
address string,
connectOptions []Option) (*grpc.ClientConn, error) {
var o options
Expand Down Expand Up @@ -189,7 +190,7 @@ func connect(
if haveConnected && !lostConnection {
// We have detected a loss of connection for the first time. Decide what to do...
// Record this once. TODO (?): log at regular time intervals.
klog.Background().Error(nil, "Lost connection", "address", address)
klog.FromContext(ctx).Error(nil, "Lost connection", "address", address)
// Inform caller and let it decide? Default is to reconnect.
if o.reconnect != nil {
reconnect = o.reconnect()
Expand All @@ -211,7 +212,7 @@ func connect(
return nil, errors.New("OnConnectionLoss callback only supported for unix:// addresses")
}

klog.Background().V(5).Info("Connecting", "address", address)
klog.FromContext(ctx).V(5).Info("Connecting", "address", address)

// Connect in background.
var conn *grpc.ClientConn
Expand All @@ -230,7 +231,7 @@ func connect(
for {
select {
case <-ticker.C:
klog.Background().Info("Still connecting", "address", address)
klog.FromContext(ctx).Info("Still connecting", "address", address)

case <-ready:
return conn, err
Expand All @@ -240,14 +241,13 @@ func connect(

// LogGRPC is gPRC unary interceptor for logging of CSI messages at level 5. It removes any secrets from the message.
func LogGRPC(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
logger := klog.FromContext(ctx)
logger.V(5).Info("GRPC call", "method", method, "request", protosanitizer.StripSecrets(req))
klog.FromContext(ctx).V(5).Info("GRPC call", "method", method, "request", protosanitizer.StripSecrets(req))
err := invoker(ctx, method, req, reply, cc, opts...)
cappedStr := protosanitizer.StripSecrets(reply).String()
if maxLogChar > 0 && len(cappedStr) > maxLogChar {
cappedStr = cappedStr[:maxLogChar] + fmt.Sprintf(" [response body too large, log capped to %d chars]", maxLogChar)
}
logger.V(5).Info("GRPC response", "response", cappedStr, "err", err)
klog.FromContext(ctx).V(5).Info("GRPC response", "response", cappedStr, "err", err)
return err
}

Expand All @@ -274,7 +274,6 @@ func (cmm ExtendedCSIMetricsManager) RecordMetricsClientInterceptor(
start := time.Now()
err := invoker(ctx, method, req, reply, cc, opts...)
duration := time.Since(start)
logger := klog.FromContext(ctx)

var cmmBase metrics.CSIMetricsManager
cmmBase = cmm
Expand All @@ -285,14 +284,14 @@ func (cmm ExtendedCSIMetricsManager) RecordMetricsClientInterceptor(
if additionalInfo != nil {
additionalInfoVal, ok := additionalInfo.(AdditionalInfo)
if !ok {
logger.Error(nil, "Failed to record migrated status, cannot convert additional info", "additionalInfo", additionalInfo)
klog.FromContext(ctx).Error(nil, "Failed to record migrated status, cannot convert additional info", "additionalInfo", additionalInfo)
return err
}
migrated = additionalInfoVal.Migrated
}
cmmv, metricsErr := cmm.WithLabelValues(map[string]string{metrics.LabelMigrated: migrated})
if metricsErr != nil {
logger.Error(metricsErr, "Failed to record migrated status")
klog.FromContext(ctx).Error(metricsErr, "Failed to record migrated status")
} else {
cmmBase = cmmv
}
Expand Down
24 changes: 12 additions & 12 deletions connection/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func TestConnect(t *testing.T) {
addr, stopServer := startServer(t, tmp, nil, nil, nil)
defer stopServer()

conn, err := Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if assert.NoError(t, err, "connect via absolute path") &&
assert.NotNil(t, conn, "got a connection") {
assert.Equal(t, connectivity.Ready, conn.GetState(), "connection ready")
Expand All @@ -123,7 +123,7 @@ func TestConnectUnix(t *testing.T) {
addr, stopServer := startServer(t, tmp, nil, nil, nil)
defer stopServer()

conn, err := Connect("unix:///"+addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := Connect(context.TODO(), "unix:///"+addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if assert.NoError(t, err, "connect with unix:/// prefix") &&
assert.NotNil(t, conn, "got a connection") {
assert.Equal(t, connectivity.Ready, conn.GetState(), "connection ready")
Expand All @@ -139,7 +139,7 @@ func TestConnectWithoutMetrics(t *testing.T) {
defer stopServer()

// With Connect
conn, err := Connect("unix:///"+addr, nil)
conn, err := Connect(context.TODO(), "unix:///"+addr, nil)
if assert.NoError(t, err, "connect with unix:/// prefix") &&
assert.NotNil(t, conn, "got a connection") {
assert.Equal(t, connectivity.Ready, conn.GetState(), "connection ready")
Expand All @@ -148,7 +148,7 @@ func TestConnectWithoutMetrics(t *testing.T) {
}

// With ConnectWithoutMetics
conn, err = ConnectWithoutMetrics("unix:///" + addr)
conn, err = ConnectWithoutMetrics(context.TODO(), "unix:///"+addr)
if assert.NoError(t, err, "connect with unix:/// prefix") &&
assert.NotNil(t, conn, "got a connection") {
assert.Equal(t, connectivity.Ready, conn.GetState(), "connection ready")
Expand All @@ -163,7 +163,7 @@ func TestConnectWithOtelTracing(t *testing.T) {
addr, stopServer := startServer(t, tmp, nil, nil, nil)
defer stopServer()

conn, err := Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), WithOtelTracing())
conn, err := Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), WithOtelTracing())
if assert.NoError(t, err, "connect via absolute path") &&
assert.NotNil(t, conn, "got a connection") {
assert.Equal(t, connectivity.Ready, conn.GetState(), "connection ready")
Expand Down Expand Up @@ -203,7 +203,7 @@ func TestWaitForServer(t *testing.T) {
startTimeServer = time.Now()
_, stopServer = startServer(t, tmp, nil, nil, nil)
}()
conn, err := Connect(path.Join(tmp, serverSock), metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := Connect(context.TODO(), path.Join(tmp, serverSock), metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if assert.NoError(t, err, "connect via absolute path") {
endTime := time.Now()
assert.NotNil(t, conn, "got a connection")
Expand All @@ -222,7 +222,7 @@ func TestTimeout(t *testing.T) {

startTime := time.Now()
timeout := 5 * time.Second
conn, err := connect(path.Join(tmp, "no-such.sock"), []Option{WithTimeout(timeout)})
conn, err := connect(context.TODO(), path.Join(tmp, "no-such.sock"), []Option{WithTimeout(timeout)})
endTime := time.Now()
if assert.Error(t, err, "connection should fail") {
assert.InEpsilon(t, timeout, endTime.Sub(startTime), 1, "connection timeout")
Expand All @@ -241,7 +241,7 @@ func TestReconnect(t *testing.T) {
}()

// Allow reconnection (the default).
conn, err := Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if assert.NoError(t, err, "connect via absolute path") &&
assert.NotNil(t, conn, "got a connection") {
defer conn.Close()
Expand Down Expand Up @@ -286,7 +286,7 @@ func TestDisconnect(t *testing.T) {
}()

reconnectCount := 0
conn, err := Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), OnConnectionLoss(func() bool {
conn, err := Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), OnConnectionLoss(func() bool {
reconnectCount++
// Don't reconnect.
return false
Expand Down Expand Up @@ -337,7 +337,7 @@ func TestExplicitReconnect(t *testing.T) {
}()

reconnectCount := 0
conn, err := Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), OnConnectionLoss(func() bool {
conn, err := Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"), OnConnectionLoss(func() bool {
reconnectCount++
// Reconnect.
return true
Expand Down Expand Up @@ -451,7 +451,7 @@ func TestConnectMetrics(t *testing.T) {
defer stopServer()

cmm := test.cmm
conn, err := Connect(addr, cmm)
conn, err := Connect(context.TODO(), addr, cmm)
if assert.NoError(t, err, "connect via absolute path") &&
assert.NotNil(t, conn, "got a connection") {
defer conn.Close()
Expand Down Expand Up @@ -526,7 +526,7 @@ func TestConnectWithOtelGrpcInterceptorTraces(t *testing.T) {
addr, stopServer := startServer(t, tmp, &identityServer{}, nil, nil)
defer stopServer()

conn, err := Connect(addr, nil, WithOtelTracing())
conn, err := Connect(context.TODO(), addr, nil, WithOtelTracing())

if assert.NoError(t, err, "connect via absolute path") &&
assert.NotNil(t, conn, "got a connection") {
Expand Down
12 changes: 7 additions & 5 deletions deprecatedflags/deprecatedflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
package deprecatedflags

import (
"context"
"flag"
"fmt"

Expand All @@ -30,19 +31,20 @@ import (
// and instead using the option triggers a deprecation warning. The
// return code can be safely ignored and is only provided to support
// replacing functions like flag.Int in a global variable section.
func Add(name string) bool {
flag.Var(deprecated{name: name}, name, "This option is deprecated.")
func Add(ctx context.Context, name string) bool {
flag.Var(deprecated{ctx: ctx, name: name}, name, "This option is deprecated.")
return true
}

// AddBool defines a deprecated boolean option. Otherwise it behaves
// like Add.
func AddBool(name string) bool {
flag.Var(deprecated{name: name, isBool: true}, name, "This option is deprecated.")
func AddBool(ctx context.Context, name string) bool {
flag.Var(deprecated{ctx: ctx, name: name, isBool: true}, name, "This option is deprecated.")
return true
}

type deprecated struct {
ctx context.Context
name string
isBool bool
}
Expand All @@ -51,7 +53,7 @@ var _ flag.Value = deprecated{}

func (d deprecated) String() string { return "" }
func (d deprecated) Set(value string) error {
klog.Background().Info("Warning: this option is deprecated and has no effect", "option", fmt.Sprintf("%s=%q", d.name, value))
klog.FromContext(d.ctx).Info("Warning: this option is deprecated and has no effect", "option", fmt.Sprintf("%s=%q", d.name, value))
return nil
}
func (d deprecated) Type() string { return "" }
Expand Down
9 changes: 4 additions & 5 deletions leaderelection/leader_election.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"strings"
"time"

"k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/kubernetes/scheme"
corev1 "k8s.io/client-go/kubernetes/typed/core/v1"
Expand Down Expand Up @@ -167,23 +167,22 @@ func (l *leaderElection) Run() error {
return err
}

logger := klog.FromContext(l.ctx)
leaderConfig := leaderelection.LeaderElectionConfig{
Lock: lock,
LeaseDuration: l.leaseDuration,
RenewDeadline: l.renewDeadline,
RetryPeriod: l.retryPeriod,
Callbacks: leaderelection.LeaderCallbacks{
OnStartedLeading: func(ctx context.Context) {
logger.V(2).Info("became leader, starting")
klog.FromContext(l.ctx).V(2).Info("became leader, starting")
l.runFunc(ctx)
},
OnStoppedLeading: func() {
logger.Error(nil, "Stopped leading")
klog.FromContext(l.ctx).Error(nil, "Stopped leading")
klog.FlushAndExit(klog.ExitFlushTimeout, 1)
},
OnNewLeader: func(identity string) {
logger.V(3).Info("New leader detected", "leader", identity)
klog.FromContext(l.ctx).V(3).Info("New leader detected", "leader", identity)
},
},
WatchDog: l.healthCheck,
Expand Down
8 changes: 4 additions & 4 deletions rpc/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ func GetGroupControllerCapabilities(ctx context.Context, conn *grpc.ClientConn)

// ProbeForever calls Probe() of a CSI driver and waits until the driver becomes ready.
// Any error other than timeout is returned.
func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error {
func ProbeForever(ctx context.Context, conn *grpc.ClientConn, singleProbeTimeout time.Duration) error {
for {
klog.Background().Info("Probing CSI driver for readiness")
klog.FromContext(ctx).Info("Probing CSI driver for readiness")
ready, err := probeOnce(conn, singleProbeTimeout)
if err != nil {
st, ok := status.FromError(err)
Expand All @@ -148,12 +148,12 @@ func ProbeForever(conn *grpc.ClientConn, singleProbeTimeout time.Duration) error
return fmt.Errorf("CSI driver probe failed: %s", err)
}
// Timeout -> driver is not ready. Fall through to sleep() below.
klog.Background().Info("CSI driver probe timed out")
klog.FromContext(ctx).Info("CSI driver probe timed out")
} else {
if ready {
return nil
}
klog.Background().Info("CSI driver is not ready")
klog.FromContext(ctx).Info("CSI driver is not ready")
}
// Timeout was returned or driver is not ready.
time.Sleep(probeInterval)
Expand Down
12 changes: 6 additions & 6 deletions rpc/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func TestGetDriverName(t *testing.T) {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := connection.Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}
Expand Down Expand Up @@ -254,7 +254,7 @@ func TestGetPluginCapabilities(t *testing.T) {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := connection.Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}
Expand Down Expand Up @@ -382,7 +382,7 @@ func TestGetControllerCapabilities(t *testing.T) {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := connection.Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}
Expand Down Expand Up @@ -476,7 +476,7 @@ func TestGetGroupControllerCapabilities(t *testing.T) {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := connection.Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}
Expand Down Expand Up @@ -610,12 +610,12 @@ func TestProbeForever(t *testing.T) {
stopServer()
}()

conn, err := connection.Connect(addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
conn, err := connection.Connect(context.TODO(), addr, metrics.NewCSIMetricsManager("fake.csi.driver.io"))
if err != nil {
t.Fatalf("Failed to connect to CSI driver: %s", err)
}

err = ProbeForever(conn, time.Second)
err = ProbeForever(context.TODO(), conn, time.Second)
if test.expectError && err == nil {
t.Errorf("Expected error, got none")
}
Expand Down

0 comments on commit 1b7a924

Please sign in to comment.