diff --git a/command/agent/acl_endpoint_test.go b/command/agent/acl_endpoint_test.go index 848ccdca6cf..39175115028 100644 --- a/command/agent/acl_endpoint_test.go +++ b/command/agent/acl_endpoint_test.go @@ -125,31 +125,46 @@ func TestHTTP_ACLPolicyCreate(t *testing.T) { p1 := mock.ACLPolicy() buf := encodeReq(p1) req, err := http.NewRequest("PUT", "/v1/acl/policy/"+p1.Name, buf) - if err != nil { - t.Fatalf("err: %v", err) - } + must.NoError(t, err) + respW := httptest.NewRecorder() setToken(req, s.RootToken) // Make the request obj, err := s.Server.ACLPolicySpecificRequest(respW, req) - assert.Nil(t, err) - assert.Nil(t, obj) + must.NoError(t, err) + must.Nil(t, obj) // Check for the index - if respW.Result().Header.Get("X-Nomad-Index") == "" { - t.Fatalf("missing index") - } + must.StrNotEqFold(t, "", respW.Result().Header.Get("X-Nomad-Index")) // Check policy was created state := s.Agent.server.State() out, err := state.ACLPolicyByName(nil, p1.Name) - assert.Nil(t, err) - assert.NotNil(t, out) + must.NoError(t, err) + must.NotNil(t, out) p1.CreateIndex, p1.ModifyIndex = out.CreateIndex, out.ModifyIndex - assert.Equal(t, p1.Name, out.Name) - assert.Equal(t, p1, out) + must.Eq(t, p1.Name, out.Name) + must.Eq(t, p1, out) + + // Create a policy that is invalid. This ensures we call the validation + // func in the RPC handler, also that the correct code and error is + // returned. + aclPolicy2 := mock.ACLPolicy() + aclPolicy2.Rules = "invalid" + + aclPolicy2Req, err := http.NewRequest(http.MethodPut, "/v1/acl/policy/"+aclPolicy2.Name, encodeReq(aclPolicy2)) + must.NoError(t, err) + + respW = httptest.NewRecorder() + setToken(aclPolicy2Req, s.RootToken) + + // Make the request + aclPolicy2Obj, err := s.Server.ACLPolicySpecificRequest(respW, aclPolicy2Req) + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "failed to parse rules") + must.Nil(t, aclPolicy2Obj) }) } diff --git a/nomad/acl_endpoint.go b/nomad/acl_endpoint.go index 15a9a7c3990..3df43b4d73e 100644 --- a/nomad/acl_endpoint.go +++ b/nomad/acl_endpoint.go @@ -100,7 +100,7 @@ func (a *ACL) UpsertPolicies(args *structs.ACLPolicyUpsertRequest, reply *struct // Validate each policy, compute hash for idx, policy := range args.Policies { if err := policy.Validate(); err != nil { - return structs.NewErrRPCCodedf(404, "policy %d invalid: %v", idx, err) + return structs.NewErrRPCCodedf(http.StatusBadRequest, "policy %d invalid: %v", idx, err) } policy.SetHash() } diff --git a/nomad/acl_endpoint_test.go b/nomad/acl_endpoint_test.go index fd3baf59090..9233f94173b 100644 --- a/nomad/acl_endpoint_test.go +++ b/nomad/acl_endpoint_test.go @@ -5,7 +5,6 @@ import ( "io/ioutil" "net/url" "path/filepath" - "strings" "testing" "time" @@ -722,10 +721,8 @@ func TestACLEndpoint_UpsertPolicies_Invalid(t *testing.T) { } var resp structs.GenericResponse err := msgpackrpc.CallWithCodec(codec, "ACL.UpsertPolicies", req, &resp) - assert.NotNil(t, err) - if !strings.Contains(err.Error(), "failed to parse") { - t.Fatalf("bad: %s", err) - } + must.ErrorContains(t, err, "400") + must.ErrorContains(t, err, "failed to parse") } func TestACLEndpoint_GetToken(t *testing.T) {