Skip to content

Commit

Permalink
Provide access to admission.Request in custom validator/defaulter
Browse files Browse the repository at this point in the history
  • Loading branch information
sbueringer authored and k8s-infra-cherrypick-robot committed Jul 6, 2022
1 parent f561596 commit b698f2b
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
40 changes: 38 additions & 2 deletions pkg/builder/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -539,11 +539,13 @@ func runTests(admissionReviewVersion string) {
// TestDefaulter.
var _ runtime.Object = &TestDefaulter{}

const testDefaulterKind = "TestDefaulter"

type TestDefaulter struct {
Replica int `json:"replica,omitempty"`
}

var testDefaulterGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: "TestDefaulter"}
var testDefaulterGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: testDefaulterKind}

func (d *TestDefaulter) GetObjectKind() schema.ObjectKind { return d }
func (d *TestDefaulter) DeepCopyObject() runtime.Object {
Expand Down Expand Up @@ -574,11 +576,13 @@ func (d *TestDefaulter) Default() {
// TestValidator.
var _ runtime.Object = &TestValidator{}

const testValidatorKind = "TestValidator"

type TestValidator struct {
Replica int `json:"replica,omitempty"`
}

var testValidatorGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: "TestValidator"}
var testValidatorGVK = schema.GroupVersionKind{Group: "foo.test.org", Version: "v1", Kind: testValidatorKind}

func (v *TestValidator) GetObjectKind() schema.ObjectKind { return v }
func (v *TestValidator) DeepCopyObject() runtime.Object {
Expand Down Expand Up @@ -694,6 +698,14 @@ func (dv *TestDefaultValidator) ValidateDelete() error {
type TestCustomDefaulter struct{}

func (*TestCustomDefaulter) Default(ctx context.Context, obj runtime.Object) error {
req, err := admission.RequestFromContext(ctx)
if err != nil {
return fmt.Errorf("expected admission.Request in ctx: %w", err)
}
if req.Kind.Kind != testDefaulterKind {
return fmt.Errorf("expected Kind TestDefaulter got %q", req.Kind.Kind)
}

d := obj.(*TestDefaulter) //nolint:ifshort
if d.Replica < 2 {
d.Replica = 2
Expand All @@ -708,6 +720,14 @@ var _ admission.CustomDefaulter = &TestCustomDefaulter{}
type TestCustomValidator struct{}

func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Object) error {
req, err := admission.RequestFromContext(ctx)
if err != nil {
return fmt.Errorf("expected admission.Request in ctx: %w", err)
}
if req.Kind.Kind != testValidatorKind {
return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind)
}

v := obj.(*TestValidator) //nolint:ifshort
if v.Replica < 0 {
return errors.New("number of replica should be greater than or equal to 0")
Expand All @@ -716,6 +736,14 @@ func (*TestCustomValidator) ValidateCreate(ctx context.Context, obj runtime.Obje
}

func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) error {
req, err := admission.RequestFromContext(ctx)
if err != nil {
return fmt.Errorf("expected admission.Request in ctx: %w", err)
}
if req.Kind.Kind != testValidatorKind {
return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind)
}

v := newObj.(*TestValidator)
old := oldObj.(*TestValidator) //nolint:ifshort
if v.Replica < 0 {
Expand All @@ -728,6 +756,14 @@ func (*TestCustomValidator) ValidateUpdate(ctx context.Context, oldObj, newObj r
}

func (*TestCustomValidator) ValidateDelete(ctx context.Context, obj runtime.Object) error {
req, err := admission.RequestFromContext(ctx)
if err != nil {
return fmt.Errorf("expected admission.Request in ctx: %w", err)
}
if req.Kind.Kind != testValidatorKind {
return fmt.Errorf("expected Kind TestValidator got %q", req.Kind.Kind)
}

v := obj.(*TestValidator) //nolint:ifshort
if v.Replica > 0 {
return errors.New("number of replica should be less than or equal to 0 to delete")
Expand Down
3 changes: 2 additions & 1 deletion pkg/webhook/admission/defaulter_custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package admission
import (
"context"
"encoding/json"

"errors"
"net/http"

Expand Down Expand Up @@ -61,6 +60,8 @@ func (h *defaulterForType) Handle(ctx context.Context, req Request) Response {
panic("object should never be nil")
}

ctx = NewContextWithRequest(ctx, req)

// Get the object in the request
obj := h.object.DeepCopyObject()
if err := h.decoder.Decode(req, obj); err != nil {
Expand Down
2 changes: 2 additions & 0 deletions pkg/webhook/admission/validator_custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ func (h *validatorForType) Handle(ctx context.Context, req Request) Response {
panic("object should never be nil")
}

ctx = NewContextWithRequest(ctx, req)

// Get the object in the request
obj := h.object.DeepCopyObject()

Expand Down
18 changes: 18 additions & 0 deletions pkg/webhook/admission/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,21 @@ func StandaloneWebhook(hook *Webhook, opts StandaloneOptions) (http.Handler, err
}
return metrics.InstrumentedHook(opts.MetricsPath, hook), nil
}

// requestContextKey is how we find the admission.Request in a context.Context.
type requestContextKey struct{}

// RequestFromContext returns an admission.Request from ctx.
func RequestFromContext(ctx context.Context) (Request, error) {
if v, ok := ctx.Value(requestContextKey{}).(Request); ok {
return v, nil
}

return Request{}, errors.New("admission.Request not found in context")
}

// NewContextWithRequest returns a new Context, derived from ctx, which carries the
// provided admission.Request.
func NewContextWithRequest(ctx context.Context, req Request) context.Context {
return context.WithValue(ctx, requestContextKey{}, req)
}
15 changes: 15 additions & 0 deletions pkg/webhook/admission/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,21 @@ var _ = Describe("Admission Webhooks", func() {
})
})

var _ = Describe("Should be able to write/read admission.Request to/from context", func() {
ctx := context.Background()
testRequest := Request{
admissionv1.AdmissionRequest{
UID: "test-uid",
},
}

ctx = NewContextWithRequest(ctx, testRequest)

gotRequest, err := RequestFromContext(ctx)
Expect(err).To(Not(HaveOccurred()))
Expect(gotRequest).To(Equal(testRequest))
})

type stringInjector interface {
InjectString(s string) error
}
Expand Down

0 comments on commit b698f2b

Please sign in to comment.