diff --git a/pkg/endpoints/handlers/patch.go b/pkg/endpoints/handlers/patch.go index b5d2192e3..588cd7917 100644 --- a/pkg/endpoints/handlers/patch.go +++ b/pkg/endpoints/handlers/patch.go @@ -186,7 +186,7 @@ func patchResource( case types.JSONPatchType, types.MergePatchType: originalJS, patchedJS, err := patchObjectJSON(patchType, codec, currentObject, patchJS, objToUpdate, versionedObj) if err != nil { - return nil, err + return nil, interpretPatchError(err) } originalObjJS, originalPatchedObjJS = originalJS, patchedJS @@ -198,13 +198,13 @@ func patchResource( // Compute once originalPatchBytes, err = strategicpatch.CreateTwoWayMergePatch(originalObjJS, originalPatchedObjJS, versionedObj) if err != nil { - return nil, err + return nil, interpretPatchError(err) } } // Return a fresh map every time originalPatchMap := make(map[string]interface{}) if err := json.Unmarshal(originalPatchBytes, &originalPatchMap); err != nil { - return nil, err + return nil, errors.NewBadRequest(err.Error()) } return originalPatchMap, nil } @@ -242,7 +242,7 @@ func patchResource( getOriginalPatchMap = func() (map[string]interface{}, error) { patchMap := make(map[string]interface{}) if err := json.Unmarshal(patchJS, &patchMap); err != nil { - return nil, err + return nil, errors.NewBadRequest(err.Error()) } return patchMap, nil } @@ -277,7 +277,7 @@ func patchResource( var err error currentPatchMap, err = strategicpatch.CreateTwoWayMergeMapPatch(originalObjMap, currentObjMap, versionedObj) if err != nil { - return nil, err + return nil, interpretPatchError(err) } } else { // Compute current patch. @@ -287,11 +287,11 @@ func patchResource( } currentPatch, err := strategicpatch.CreateTwoWayMergePatch(originalObjJS, currentObjJS, versionedObj) if err != nil { - return nil, err + return nil, interpretPatchError(err) } currentPatchMap = make(map[string]interface{}) if err := json.Unmarshal(currentPatch, ¤tPatchMap); err != nil { - return nil, err + return nil, errors.NewBadRequest(err.Error()) } } @@ -422,7 +422,7 @@ func strategicPatchObject( patchMap := make(map[string]interface{}) if err := json.Unmarshal(patchJS, &patchMap); err != nil { - return err + return errors.NewBadRequest(err.Error()) } if err := applyPatchToObject(codec, defaulter, originalObjMap, patchMap, objToUpdate, versionedObj); err != nil { @@ -456,3 +456,15 @@ func applyPatchToObject( return nil } + +// interpretPatchError interprets the error type and returns an error with appropriate HTTP code. +func interpretPatchError(err error) error { + switch err { + case mergepatch.ErrBadJSONDoc, mergepatch.ErrBadPatchFormatForPrimitiveList, mergepatch.ErrBadPatchFormatForRetainKeys, mergepatch.ErrBadPatchFormatForSetElementOrderList: + return errors.NewBadRequest(err.Error()) + case mergepatch.ErrNoListOfLists, mergepatch.ErrPatchContentNotMatchRetainKeys: + return errors.NewGenericServerResponse(http.StatusUnprocessableEntity, "", schema.GroupResource{}, "", err.Error(), 0, false) + default: + return err + } +} diff --git a/pkg/endpoints/handlers/rest_test.go b/pkg/endpoints/handlers/rest_test.go index a99fdaf6f..5e6ca117a 100644 --- a/pkg/endpoints/handlers/rest_test.go +++ b/pkg/endpoints/handlers/rest_test.go @@ -98,6 +98,29 @@ func TestPatchAnonymousField(t *testing.T) { } } +func TestPatchInvalid(t *testing.T) { + testGV := schema.GroupVersion{Group: "", Version: "v"} + scheme.AddKnownTypes(testGV, &testPatchType{}) + codec := codecs.LegacyCodec(testGV) + defaulter := runtime.ObjectDefaulter(scheme) + + original := &testPatchType{ + TypeMeta: metav1.TypeMeta{Kind: "testPatchType", APIVersion: "v"}, + TestPatchSubType: TestPatchSubType{StringField: "my-value"}, + } + patch := `barbaz` + expectedError := "invalid character 'b' looking for beginning of value" + + actual := &testPatchType{} + err := strategicPatchObject(codec, defaulter, original, []byte(patch), actual, &testPatchType{}) + if apierrors.IsBadRequest(err) == false { + t.Errorf("expected HTTP status: BadRequest, got: %#v", apierrors.ReasonForError(err)) + } + if err.Error() != expectedError { + t.Errorf("expected %#v, got %#v", expectedError, err.Error()) + } +} + type testPatcher struct { t *testing.T