mirror of https://github.com/knative/client.git
132 lines
3.5 KiB
Go
132 lines
3.5 KiB
Go
// Copyright © 2019 The Knative Authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package mock
|
|
|
|
import (
|
|
"fmt"
|
|
"reflect"
|
|
"testing"
|
|
|
|
"gotest.tools/assert"
|
|
)
|
|
|
|
// Recorded method call
|
|
type ApiMethodCall struct {
|
|
Args []interface{}
|
|
Result []interface{}
|
|
}
|
|
|
|
// Recorder for recording mock call
|
|
type Recorder struct {
|
|
|
|
// Test object used for asserting
|
|
t *testing.T
|
|
|
|
// List of recorded calls in order
|
|
recordedCalls map[string][]ApiMethodCall
|
|
|
|
// Namespace for client
|
|
namespace string
|
|
}
|
|
|
|
// Add a recorded api call the list of calls
|
|
func (r *Recorder) Add(name string, args []interface{}, result []interface{}) {
|
|
call := ApiMethodCall{args, result}
|
|
calls, ok := r.recordedCalls[name]
|
|
if !ok {
|
|
calls = make([]ApiMethodCall, 0)
|
|
r.recordedCalls[name] = calls
|
|
}
|
|
r.recordedCalls[name] = append(calls, call)
|
|
}
|
|
|
|
// Get the next recorded call
|
|
func (r *Recorder) Shift(name string) (*ApiMethodCall, error) {
|
|
calls := r.recordedCalls[name]
|
|
if len(calls) == 0 {
|
|
return nil, fmt.Errorf("no call to '%s' recorded", name)
|
|
}
|
|
call, calls := calls[0], calls[1:]
|
|
r.recordedCalls[name] = calls
|
|
return &call, nil
|
|
}
|
|
|
|
func NewRecorder(t *testing.T, namespace string) *Recorder {
|
|
return &Recorder{
|
|
t: t,
|
|
recordedCalls: make(map[string][]ApiMethodCall),
|
|
namespace: namespace,
|
|
}
|
|
}
|
|
|
|
func (r *Recorder) Namespace() string {
|
|
return r.namespace
|
|
}
|
|
|
|
// Check if every method has been called
|
|
func (r *Recorder) CheckThatAllRecordedMethodsHaveBeenCalled() {
|
|
for k, v := range r.recordedCalls {
|
|
if len(v) > 0 {
|
|
r.t.Errorf("Recorded method \"%s\" not been called", k)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Verify given arguments against recorded arguments
|
|
func (r *Recorder) VerifyCall(name string, args ...interface{}) *ApiMethodCall {
|
|
call := r.getCall(name)
|
|
callArgs := call.Args
|
|
for i, arg := range args {
|
|
assert.Assert(r.t, len(callArgs) > i, "Internal: Invalid recording: Expected %d args, got %d", len(callArgs), len(args))
|
|
fn := reflect.ValueOf(callArgs[i])
|
|
fnType := fn.Type()
|
|
if fnType.Kind() == reflect.Func {
|
|
if fnType.NumIn() == 2 &&
|
|
// It's an assertion function which takes a Testing as first parameter
|
|
fnType.In(0).AssignableTo(reflect.TypeOf(r.t)) {
|
|
fn.Call([]reflect.Value{reflect.ValueOf(r.t), reflect.ValueOf(arg)})
|
|
} else {
|
|
assert.Assert(r.t, fnType.AssignableTo(reflect.TypeOf(arg)))
|
|
}
|
|
} else {
|
|
assert.DeepEqual(r.t, callArgs[i], arg)
|
|
}
|
|
}
|
|
return call
|
|
}
|
|
|
|
// Get call and verify that it exist
|
|
func (r *Recorder) getCall(name string) *ApiMethodCall {
|
|
call, err := r.Shift(name)
|
|
assert.NilError(r.t, err, "invalid mock setup, missing recording step")
|
|
return call
|
|
}
|
|
|
|
// =====================================================================
|
|
// Helper methods
|
|
|
|
// mock.Any() can be used in recording to not check for the argument
|
|
func Any() func(t *testing.T, a interface{}) {
|
|
return func(t *testing.T, a interface{}) {}
|
|
}
|
|
|
|
// Helper method to cast to an error if given
|
|
func ErrorOrNil(err interface{}) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
return err.(error)
|
|
}
|