Update to gotest.tools/v3

We have been using a version over 6 years old.

Signed-off-by: Miloslav Trmač <mitr@redhat.com>
This commit is contained in:
Miloslav Trmač 2025-03-14 03:59:17 +01:00
parent 05de8c7758
commit dd32248f47
23 changed files with 960 additions and 442 deletions

View File

@ -14,8 +14,8 @@ import (
"github.com/containers/storage/pkg/system"
"golang.org/x/sys/unix"
"gotest.tools/assert"
is "gotest.tools/assert/cmp"
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
)
func TestCopy(t *testing.T) {

2
go.mod
View File

@ -31,7 +31,7 @@ require (
github.com/vbatts/tar-split v0.12.1
golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0
gotest.tools v2.2.0+incompatible
gotest.tools/v3 v3.5.2
)
require (

4
go.sum
View File

@ -198,7 +198,7 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -14,7 +14,7 @@ import (
"github.com/containers/storage/pkg/archive"
"golang.org/x/sys/unix"
"gotest.tools/assert"
"gotest.tools/v3/assert"
)
// Test for CVE-2018-15664

View File

@ -11,7 +11,7 @@ import (
"github.com/containers/storage/pkg/unshare"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"gotest.tools/assert"
"gotest.tools/v3/assert"
)
func TestGetRootlessStorageOpts(t *testing.T) {

View File

@ -7,7 +7,7 @@ import (
"testing"
"github.com/containers/storage/pkg/unshare"
"gotest.tools/assert"
"gotest.tools/v3/assert"
)
func TestDefaultStoreOpts(t *testing.T) {

View File

@ -1,311 +0,0 @@
/*Package assert provides assertions for comparing expected values to actual
values. When an assertion fails a helpful error message is printed.
Assert and Check
Assert() and Check() both accept a Comparison, and fail the test when the
comparison fails. The one difference is that Assert() will end the test execution
immediately (using t.FailNow()) whereas Check() will fail the test (using t.Fail()),
return the value of the comparison, then proceed with the rest of the test case.
Example usage
The example below shows assert used with some common types.
import (
"testing"
"gotest.tools/assert"
is "gotest.tools/assert/cmp"
)
func TestEverything(t *testing.T) {
// booleans
assert.Assert(t, ok)
assert.Assert(t, !missing)
// primitives
assert.Equal(t, count, 1)
assert.Equal(t, msg, "the message")
assert.Assert(t, total != 10) // NotEqual
// errors
assert.NilError(t, closer.Close())
assert.Error(t, err, "the exact error message")
assert.ErrorContains(t, err, "includes this")
assert.ErrorType(t, err, os.IsNotExist)
// complex types
assert.DeepEqual(t, result, myStruct{Name: "title"})
assert.Assert(t, is.Len(items, 3))
assert.Assert(t, len(sequence) != 0) // NotEmpty
assert.Assert(t, is.Contains(mapping, "key"))
// pointers and interface
assert.Assert(t, is.Nil(ref))
assert.Assert(t, ref != nil) // NotNil
}
Comparisons
Package https://godoc.org/gotest.tools/assert/cmp provides
many common comparisons. Additional comparisons can be written to compare
values in other ways. See the example Assert (CustomComparison).
Automated migration from testify
gty-migrate-from-testify is a binary which can update source code which uses
testify assertions to use the assertions provided by this package.
See http://bit.do/cmd-gty-migrate-from-testify.
*/
package assert // import "gotest.tools/assert"
import (
"fmt"
"go/ast"
"go/token"
gocmp "github.com/google/go-cmp/cmp"
"gotest.tools/assert/cmp"
"gotest.tools/internal/format"
"gotest.tools/internal/source"
)
// BoolOrComparison can be a bool, or cmp.Comparison. See Assert() for usage.
type BoolOrComparison interface{}
// TestingT is the subset of testing.T used by the assert package.
type TestingT interface {
FailNow()
Fail()
Log(args ...interface{})
}
type helperT interface {
Helper()
}
const failureMessage = "assertion failed: "
// nolint: gocyclo
func assert(
t TestingT,
failer func(),
argSelector argSelector,
comparison BoolOrComparison,
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)
// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)
case nil:
return true
case error:
msg := "error is not nil: "
t.Log(format.WithCustomMessage(failureMessage+msg+check.Error(), msgAndArgs...))
case cmp.Comparison:
success = runComparison(t, argSelector, check, msgAndArgs...)
case func() cmp.Result:
success = runComparison(t, argSelector, check, msgAndArgs...)
default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}
if success {
return true
}
failer()
return false
}
func runCompareFunc(
t TestingT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}
func logFailureFromBool(t TestingT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), formatFailureFromBool()
const comparisonArgPos = 1
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
return
}
msg, err := boolFailureMessage(args[comparisonArgPos])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok && binaryExpr.Op == token.NEQ {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
return x + " is " + y, nil
}
if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}
formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}
// Assert performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is logged, and execution is stopped immediately.
//
// The comparison argument may be one of three types: bool, cmp.Comparison or
// error.
// When called with a bool the failure message will contain the literal source
// code of the expression.
// When called with a cmp.Comparison the comparison is responsible for producing
// a helpful failure message.
// When called with an error a nil value is considered success. A non-nil error
// is a failure, and Error() is used as the failure message.
func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsFromComparisonCall, comparison, msgAndArgs...)
}
// Check performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is logged, and Check returns false. Otherwise returns
// true.
//
// See Assert for details about the comparison arg and failure messages.
func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
return assert(t, t.Fail, argsFromComparisonCall, comparison, msgAndArgs...)
}
// NilError fails the test immediately if err is not nil.
// This is equivalent to Assert(t, err)
func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, err, msgAndArgs...)
}
// Equal uses the == operator to assert two values are equal and fails the test
// if they are not equal.
//
// If the comparison fails Equal will use the variable names for x and y as part
// of the failure message to identify the actual and expected values.
//
// If either x or y are a multi-line string the failure message will include a
// unified diff of the two values. If the values only differ by whitespace
// the unified diff will be augmented by replacing whitespace characters with
// visible characters to identify the whitespace difference.
//
// This is equivalent to Assert(t, cmp.Equal(x, y)).
func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Equal(x, y), msgAndArgs...)
}
// DeepEqual uses google/go-cmp (http://bit.do/go-cmp) to assert two values are
// equal and fails the test if they are not equal.
//
// Package https://godoc.org/gotest.tools/assert/opt provides some additional
// commonly used Options.
//
// This is equivalent to Assert(t, cmp.DeepEqual(x, y)).
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.DeepEqual(x, y, opts...))
}
// Error fails the test if err is nil, or the error message is not the expected
// message.
// Equivalent to Assert(t, cmp.Error(err, message)).
func Error(t TestingT, err error, message string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.Error(err, message), msgAndArgs...)
}
// ErrorContains fails the test if err is nil, or the error message does not
// contain the expected substring.
// Equivalent to Assert(t, cmp.ErrorContains(err, substring)).
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...)
}
// ErrorType fails the test if err is nil, or err is not the expected type.
//
// Expected can be one of:
// a func(error) bool which returns true if the error is the expected type,
// an instance of (or a pointer to) a struct of the expected type,
// a pointer to an interface the error is expected to implement,
// a reflect.Type of the expected struct or interface.
//
// Equivalent to Assert(t, cmp.ErrorType(err, expected)).
func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
assert(t, t.FailNow, argsAfterT, cmp.ErrorType(err, expected), msgAndArgs...)
}

311
vendor/gotest.tools/v3/assert/assert.go vendored Normal file
View File

@ -0,0 +1,311 @@
/*
Package assert provides assertions for comparing expected values to actual
values in tests. When an assertion fails a helpful error message is printed.
# Example usage
All the assertions in this package use [testing.T.Helper] to mark themselves as
test helpers. This allows the testing package to print the filename and line
number of the file function that failed.
assert.NilError(t, err)
// filename_test.go:212: assertion failed: error is not nil: file not found
If any assertion is called from a helper function, make sure to call t.Helper
from the helper function so that the filename and line number remain correct.
The examples below show assert used with some common types and the failure
messages it produces. The filename and line number portion of the failure
message is omitted from these examples for brevity.
// booleans
assert.Assert(t, ok)
// assertion failed: ok is false
assert.Assert(t, !missing)
// assertion failed: missing is true
// primitives
assert.Equal(t, count, 1)
// assertion failed: 0 (count int) != 1 (int)
assert.Equal(t, msg, "the message")
// assertion failed: my message (msg string) != the message (string)
assert.Assert(t, total != 10) // use Assert for NotEqual
// assertion failed: total is 10
assert.Assert(t, count > 20, "count=%v", count)
// assertion failed: count is <= 20: count=1
// errors
assert.NilError(t, closer.Close())
// assertion failed: error is not nil: close /file: errno 11
assert.Error(t, err, "the exact error message")
// assertion failed: expected error "the exact error message", got "oops"
assert.ErrorContains(t, err, "includes this")
// assertion failed: expected error to contain "includes this", got "oops"
assert.ErrorIs(t, err, os.ErrNotExist)
// assertion failed: error is "oops", not "file does not exist" (os.ErrNotExist)
// complex types
assert.DeepEqual(t, result, myStruct{Name: "title"})
// assertion failed: ... (diff of the two structs)
assert.Assert(t, is.Len(items, 3))
// assertion failed: expected [] (length 0) to have length 3
assert.Assert(t, len(sequence) != 0) // use Assert for NotEmpty
// assertion failed: len(sequence) is 0
assert.Assert(t, is.Contains(mapping, "key"))
// assertion failed: map[other:1] does not contain key
// pointers and interface
assert.Assert(t, ref == nil)
// assertion failed: ref is not nil
assert.Assert(t, ref != nil) // use Assert for NotNil
// assertion failed: ref is nil
# Assert and Check
[Assert] and [Check] are very similar, they both accept a [cmp.Comparison], and fail
the test when the comparison fails. The one difference is that Assert uses
[testing.T.FailNow] to fail the test, which will end the test execution immediately.
Check uses [testing.T.Fail] to fail the test, which allows it to return the
result of the comparison, then proceed with the rest of the test case.
Like [testing.T.FailNow], [Assert] must be called from the goroutine running the test,
not from other goroutines created during the test. [Check] is safe to use from any
goroutine.
# Comparisons
Package [gotest.tools/v3/assert/cmp] provides
many common comparisons. Additional comparisons can be written to compare
values in other ways. See the example Assert (CustomComparison).
# Automated migration from testify
gty-migrate-from-testify is a command which translates Go source code from
testify assertions to the assertions provided by this package.
See http://pkg.go.dev/gotest.tools/v3/assert/cmd/gty-migrate-from-testify.
*/
package assert // import "gotest.tools/v3/assert"
import (
gocmp "github.com/google/go-cmp/cmp"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/assert"
)
// BoolOrComparison can be a bool, [cmp.Comparison], or error. See [Assert] for
// details about how this type is used.
type BoolOrComparison interface{}
// TestingT is the subset of [testing.T] (see also [testing.TB]) used by the assert package.
type TestingT interface {
FailNow()
Fail()
Log(args ...interface{})
}
type helperT interface {
Helper()
}
// Assert performs a comparison. If the comparison fails, the test is marked as
// failed, a failure message is logged, and execution is stopped immediately.
//
// The comparison argument may be one of three types:
//
// bool
// True is success. False is a failure. The failure message will contain
// the literal source code of the expression.
//
// cmp.Comparison
// Uses cmp.Result.Success() to check for success or failure.
// The comparison is responsible for producing a helpful failure message.
// http://pkg.go.dev/gotest.tools/v3/assert/cmp provides many common comparisons.
//
// error
// A nil value is considered success, and a non-nil error is a failure.
// The return value of error.Error is used as the failure message.
//
// Extra details can be added to the failure message using msgAndArgs. msgAndArgs
// may be either a single string, or a format string and args that will be
// passed to [fmt.Sprintf].
//
// Assert uses [testing.TB.FailNow] to fail the test. Like t.FailNow, Assert must be called
// from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] from other goroutines.
func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.FailNow()
}
}
// Check performs a comparison. If the comparison fails the test is marked as
// failed, a failure message is printed, and Check returns false. If the comparison
// is successful Check returns true. Check may be called from any goroutine.
//
// See [Assert] for details about the comparison arg and failure messages.
func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsFromComparisonCall, comparison, msgAndArgs...) {
t.Fail()
return false
}
return true
}
// NilError fails the test immediately if err is not nil, and includes err.Error
// in the failure message.
//
// NilError uses [testing.TB.FailNow] to fail the test. Like t.FailNow, NilError must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] from other goroutines.
func NilError(t TestingT, err error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, err, msgAndArgs...) {
t.FailNow()
}
}
// Equal uses the == operator to assert two values are equal and fails the test
// if they are not equal.
//
// If the comparison fails Equal will use the variable names and types of
// x and y as part of the failure message to identify the actual and expected
// values.
//
// assert.Equal(t, actual, expected)
// // main_test.go:41: assertion failed: 1 (actual int) != 21 (expected int32)
//
// If either x or y are a multi-line string the failure message will include a
// unified diff of the two values. If the values only differ by whitespace
// the unified diff will be augmented by replacing whitespace characters with
// visible characters to identify the whitespace difference.
//
// Equal uses [testing.T.FailNow] to fail the test. Like t.FailNow, Equal must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.Equal] from other
// goroutines.
func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.Equal(x, y), msgAndArgs...) {
t.FailNow()
}
}
// DeepEqual uses [github.com/google/go-cmp/cmp]
// to assert two values are equal and fails the test if they are not equal.
//
// Package [gotest.tools/v3/assert/opt] provides some additional
// commonly used Options.
//
// DeepEqual uses [testing.T.FailNow] to fail the test. Like t.FailNow, DeepEqual must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.DeepEqual] from other
// goroutines.
func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.DeepEqual(x, y, opts...)) {
t.FailNow()
}
}
// Error fails the test if err is nil, or if err.Error is not equal to expected.
// Both err.Error and expected will be included in the failure message.
// Error performs an exact match of the error text. Use [ErrorContains] if only
// part of the error message is relevant. Use [ErrorType] or [ErrorIs] to compare
// errors by type.
//
// Error uses [testing.T.FailNow] to fail the test. Like t.FailNow, Error must be
// called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.Error] from other
// goroutines.
func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.Error(err, expected), msgAndArgs...) {
t.FailNow()
}
}
// ErrorContains fails the test if err is nil, or if err.Error does not
// contain the expected substring. Both err.Error and the expected substring
// will be included in the failure message.
//
// ErrorContains uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorContains
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorContains] from other
// goroutines.
func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorContains(err, substring), msgAndArgs...) {
t.FailNow()
}
}
// ErrorType fails the test if err is nil, or err is not the expected type.
// New code should use ErrorIs instead.
//
// Expected can be one of:
//
// func(error) bool
// The function should return true if the error is the expected type.
//
// struct{} or *struct{}
// A struct or a pointer to a struct. The assertion fails if the error is
// not of the same type.
//
// *interface{}
// A pointer to an interface type. The assertion fails if err does not
// implement the interface.
//
// reflect.Type
// The assertion fails if err does not implement the reflect.Type.
//
// ErrorType uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorType
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorType] from other
// goroutines.
func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorType(err, expected), msgAndArgs...) {
t.FailNow()
}
}
// ErrorIs fails the test if err is nil, or the error does not match expected
// when compared using errors.Is. See [errors.Is] for
// accepted arguments.
//
// ErrorIs uses [testing.T.FailNow] to fail the test. Like t.FailNow, ErrorIs
// must be called from the goroutine running the test function, not from other
// goroutines created during the test. Use [Check] with [cmp.ErrorIs] from other
// goroutines.
func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if !assert.Eval(t, assert.ArgsAfterT, cmp.ErrorIs(err, expected), msgAndArgs...) {
t.FailNow()
}
}

View File

@ -1,26 +1,27 @@
/*Package cmp provides Comparisons for Assert and Check*/
package cmp // import "gotest.tools/assert/cmp"
package cmp // import "gotest.tools/v3/assert/cmp"
import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"github.com/google/go-cmp/cmp"
"gotest.tools/internal/format"
"gotest.tools/v3/internal/format"
)
// Comparison is a function which compares values and returns ResultSuccess if
// Comparison is a function which compares values and returns [ResultSuccess] if
// the actual value matches the expected value. If the values do not match the
// Result will contain a message about why it failed.
// [Result] will contain a message about why it failed.
type Comparison func() Result
// DeepEqual compares two values using google/go-cmp (http://bit.do/go-cmp)
// DeepEqual compares two values using [github.com/google/go-cmp/cmp]
// and succeeds if the values are equal.
//
// The comparison can be customized using comparison Options.
// Package https://godoc.org/gotest.tools/assert/opt provides some additional
// Package [gotest.tools/v3/assert/opt] provides some additional
// commonly used Options.
func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
return func() (result Result) {
@ -33,7 +34,7 @@ func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
if diff == "" {
return ResultSuccess
}
return multiLineDiffResult(diff)
return multiLineDiffResult(diff, x, y)
}
}
@ -59,16 +60,17 @@ func toResult(success bool, msg string) Result {
return ResultFailure(msg)
}
// RegexOrPattern may be either a *regexp.Regexp or a string that is a valid
// RegexOrPattern may be either a [*regexp.Regexp] or a string that is a valid
// regexp pattern.
type RegexOrPattern interface{}
// Regexp succeeds if value v matches regular expression re.
//
// Example:
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
// r := regexp.MustCompile("^[0-9a-f]{32}$")
// assert.Assert(t, cmp.Regexp(r, str))
//
// assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
// r := regexp.MustCompile("^[0-9a-f]{32}$")
// assert.Assert(t, cmp.Regexp(r, str))
func Regexp(re RegexOrPattern, v string) Comparison {
match := func(re *regexp.Regexp) Result {
return toResult(
@ -92,7 +94,7 @@ func Regexp(re RegexOrPattern, v string) Comparison {
}
}
// Equal succeeds if x == y. See assert.Equal for full documentation.
// Equal succeeds if x == y. See [gotest.tools/v3/assert.Equal] for full documentation.
func Equal(x, y interface{}) Comparison {
return func() Result {
switch {
@ -100,13 +102,13 @@ func Equal(x, y interface{}) Comparison {
return ResultSuccess
case isMultiLineStringCompare(x, y):
diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
return multiLineDiffResult(diff)
return multiLineDiffResult(diff, x, y)
}
return ResultFailureTemplate(`
{{- .Data.x}} (
{{- printf "%v" .Data.x}} (
{{- with callArg 0 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.x -}}
) != {{ .Data.y}} (
) != {{ printf "%v" .Data.y}} (
{{- with callArg 1 }}{{ formatNode . }} {{end -}}
{{- printf "%T" .Data.y -}}
)`,
@ -126,12 +128,12 @@ func isMultiLineStringCompare(x, y interface{}) bool {
return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
}
func multiLineDiffResult(diff string) Result {
func multiLineDiffResult(diff string, x, y interface{}) Result {
return ResultFailureTemplate(`
--- {{ with callArg 0 }}{{ formatNode . }}{{else}}{{end}}
+++ {{ with callArg 1 }}{{ formatNode . }}{{else}}{{end}}
{{ .Data.diff }}`,
map[string]interface{}{"diff": diff})
map[string]interface{}{"diff": diff, "x": x, "y": y})
}
// Len succeeds if the sequence has the expected length.
@ -156,15 +158,15 @@ func Len(seq interface{}, expected int) Comparison {
// slice, or array.
//
// If collection is a string, item must also be a string, and is compared using
// strings.Contains().
// [strings.Contains].
// If collection is a Map, contains will succeed if item is a key in the map.
// If collection is a slice or array, item is compared to each item in the
// sequence using reflect.DeepEqual().
// sequence using [reflect.DeepEqual].
func Contains(collection interface{}, item interface{}) Comparison {
return func() Result {
colValue := reflect.ValueOf(collection)
if !colValue.IsValid() {
return ResultFailure(fmt.Sprintf("nil does not contain items"))
return ResultFailure("nil does not contain items")
}
msg := fmt.Sprintf("%v does not contain %v", collection, item)
@ -241,10 +243,13 @@ func ErrorContains(err error, substring string) Comparison {
}
}
type causer interface {
Cause() error
}
func formatErrorMessage(err error) string {
if _, ok := err.(interface {
Cause() error
}); ok {
//nolint:errorlint,nolintlint // unwrapping is not appropriate here
if _, ok := err.(causer); ok {
return fmt.Sprintf("%q\n%+v", err, err)
}
// This error was not wrapped with github.com/pkg/errors
@ -253,7 +258,7 @@ func formatErrorMessage(err error) string {
// Nil succeeds if obj is a nil interface, pointer, or function.
//
// Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices,
// Use [gotest.tools/v3/assert.NilError] for comparing errors. Use Len(obj, 0) for comparing slices,
// maps, and channels.
func Nil(obj interface{}) Comparison {
msgFunc := func(value reflect.Value) string {
@ -281,12 +286,27 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
}
// ErrorType succeeds if err is not nil and is of the expected type.
// New code should use [ErrorIs] instead.
//
// Expected can be one of:
// a func(error) bool which returns true if the error is the expected type,
// an instance of (or a pointer to) a struct of the expected type,
// a pointer to an interface the error is expected to implement,
// a reflect.Type of the expected struct or interface.
//
// func(error) bool
//
// Function should return true if the error is the expected type.
//
// type struct{}, type &struct{}
//
// A struct or a pointer to a struct.
// Fails if the error is not of the same type as expected.
//
// type &interface{}
//
// A pointer to an interface type.
// Fails if err does not implement the interface.
//
// reflect.Type
//
// Fails if err does not implement the [reflect.Type].
func ErrorType(err error, expected interface{}) Comparison {
return func() Result {
switch expectedType := expected.(type) {
@ -298,7 +318,7 @@ func ErrorType(err error, expected interface{}) Comparison {
}
return cmpErrorTypeEqualType(err, expectedType)
case nil:
return ResultFailure(fmt.Sprintf("invalid type for expected: nil"))
return ResultFailure("invalid type for expected: nil")
}
expectedType := reflect.TypeOf(expected)
@ -354,3 +374,30 @@ func isPtrToInterface(typ reflect.Type) bool {
func isPtrToStruct(typ reflect.Type) bool {
return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
}
var (
stdlibErrorNewType = reflect.TypeOf(errors.New(""))
stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", fmt.Errorf("")))
)
// ErrorIs succeeds if errors.Is(actual, expected) returns true. See
// [errors.Is] for accepted argument values.
func ErrorIs(actual error, expected error) Comparison {
return func() Result {
if errors.Is(actual, expected) {
return ResultSuccess
}
// The type of stdlib errors is excluded because the type is not relevant
// in those cases. The type is only important when it is a user defined
// custom error type.
return ResultFailureTemplate(`error is
{{- if not .Data.a }} nil,{{ else }}
{{- printf " \"%v\"" .Data.a }}
{{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
{{- end }} not {{ printf "\"%v\"" .Data.x }} (
{{- with callArg 1 }}{{ formatNode . }}{{ end }}
{{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
map[string]interface{}{"a": actual, "x": expected})
}
}

View File

@ -4,39 +4,46 @@ import (
"bytes"
"fmt"
"go/ast"
"reflect"
"text/template"
"gotest.tools/internal/source"
"gotest.tools/v3/internal/source"
)
// Result of a Comparison.
// A Result of a [Comparison].
type Result interface {
Success() bool
}
type result struct {
// StringResult is an implementation of [Result] that reports the error message
// string verbatim and does not provide any templating or formatting of the
// message.
type StringResult struct {
success bool
message string
}
func (r result) Success() bool {
// Success returns true if the comparison was successful.
func (r StringResult) Success() bool {
return r.success
}
func (r result) FailureMessage() string {
// FailureMessage returns the message used to provide additional information
// about the failure.
func (r StringResult) FailureMessage() string {
return r.message
}
// ResultSuccess is a constant which is returned by a ComparisonWithResult to
// ResultSuccess is a constant which is returned by a [Comparison] to
// indicate success.
var ResultSuccess = result{success: true}
var ResultSuccess = StringResult{success: true}
// ResultFailure returns a failed Result with a failure message.
func ResultFailure(message string) Result {
return result{message: message}
// ResultFailure returns a failed [Result] with a failure message.
func ResultFailure(message string) StringResult {
return StringResult{message: message}
}
// ResultFromError returns ResultSuccess if err is nil. Otherwise ResultFailure
// ResultFromError returns [ResultSuccess] if err is nil. Otherwise [ResultFailure]
// is returned with the error message as the failure message.
func ResultFromError(err error) Result {
if err == nil {
@ -46,13 +53,12 @@ func ResultFromError(err error) Result {
}
type templatedResult struct {
success bool
template string
data map[string]interface{}
}
func (r templatedResult) Success() bool {
return r.success
return false
}
func (r templatedResult) FailureMessage(args []ast.Expr) string {
@ -63,7 +69,12 @@ func (r templatedResult) FailureMessage(args []ast.Expr) string {
return msg
}
// ResultFailureTemplate returns a Result with a template string and data which
func (r templatedResult) UpdatedExpected(stackIndex int) error {
// TODO: would be nice to have structured data instead of a map
return source.UpdateExpectedValue(stackIndex+1, r.data["x"], r.data["y"])
}
// ResultFailureTemplate returns a [Result] with a template string and data which
// can be used to format a failure message. The template may access data from .Data,
// the comparison args with the callArg function, and the formatNode function may
// be used to format the call args.
@ -80,6 +91,11 @@ func renderMessage(result templatedResult, args []ast.Expr) (string, error) {
}
return args[index]
},
// TODO: any way to include this from ErrorIS instead of here?
"notStdlibErrorType": func(typ interface{}) bool {
r := reflect.TypeOf(typ)
return r != stdlibFmtErrorType && r != stdlibErrorNewType
},
})
var err error
tmpl, err = tmpl.Parse(result.template)

View File

@ -0,0 +1,160 @@
// Package assert provides internal utilties for assertions.
package assert
import (
"fmt"
"go/ast"
"go/token"
"reflect"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/format"
"gotest.tools/v3/internal/source"
)
// LogT is the subset of testing.T used by the assert package.
type LogT interface {
Log(args ...interface{})
}
type helperT interface {
Helper()
}
const failureMessage = "assertion failed: "
// Eval the comparison and print a failure messages if the comparison has failed.
func Eval(
t LogT,
argSelector argSelector,
comparison interface{},
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
var success bool
switch check := comparison.(type) {
case bool:
if check {
return true
}
logFailureFromBool(t, msgAndArgs...)
// Undocumented legacy comparison without Result type
case func() (success bool, message string):
success = runCompareFunc(t, check, msgAndArgs...)
case nil:
return true
case error:
msg := failureMsgFromError(check)
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
case cmp.Comparison:
success = RunComparison(t, argSelector, check, msgAndArgs...)
case func() cmp.Result:
success = RunComparison(t, argSelector, check, msgAndArgs...)
default:
t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
}
return success
}
func runCompareFunc(
t LogT,
f func() (success bool, message string),
msgAndArgs ...interface{},
) bool {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
if success, message := f(); !success {
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
return false
}
return true
}
func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
if ht, ok := t.(helperT); ok {
ht.Helper()
}
const stackIndex = 3 // Assert()/Check(), assert(), logFailureFromBool()
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
}
var msg string
const comparisonArgIndex = 1 // Assert(t, comparison)
if len(args) <= comparisonArgIndex {
msg = "but assert failed to find the expression to print"
} else {
msg, err = boolFailureMessage(args[comparisonArgIndex])
if err != nil {
t.Log(err.Error())
msg = "expression is false"
}
}
t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
func failureMsgFromError(err error) string {
// Handle errors with non-nil types
v := reflect.ValueOf(err)
if v.Kind() == reflect.Ptr && v.IsNil() {
return fmt.Sprintf("error is not nil: error has type %T", err)
}
return "error is not nil: " + err.Error()
}
func boolFailureMessage(expr ast.Expr) (string, error) {
if binaryExpr, ok := expr.(*ast.BinaryExpr); ok {
x, err := source.FormatNode(binaryExpr.X)
if err != nil {
return "", err
}
y, err := source.FormatNode(binaryExpr.Y)
if err != nil {
return "", err
}
switch binaryExpr.Op {
case token.NEQ:
return x + " is " + y, nil
case token.EQL:
return x + " is not " + y, nil
case token.GTR:
return x + " is <= " + y, nil
case token.LSS:
return x + " is >= " + y, nil
case token.GEQ:
return x + " is less than " + y, nil
case token.LEQ:
return x + " is greater than " + y, nil
}
}
if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
x, err := source.FormatNode(unaryExpr.X)
if err != nil {
return "", err
}
return x + " is true", nil
}
if ident, ok := expr.(*ast.Ident); ok {
return ident.Name + " is false", nil
}
formatted, err := source.FormatNode(expr)
if err != nil {
return "", err
}
return "expression is false: " + formatted, nil
}

View File

@ -1,16 +1,19 @@
package assert
import (
"errors"
"fmt"
"go/ast"
"gotest.tools/assert/cmp"
"gotest.tools/internal/format"
"gotest.tools/internal/source"
"gotest.tools/v3/assert/cmp"
"gotest.tools/v3/internal/format"
"gotest.tools/v3/internal/source"
)
func runComparison(
t TestingT,
// RunComparison and return Comparison.Success. If the comparison fails a messages
// will be printed using t.Log.
func RunComparison(
t LogT,
argSelector argSelector,
f cmp.Comparison,
msgAndArgs ...interface{},
@ -23,10 +26,26 @@ func runComparison(
return true
}
if source.IsUpdate() {
if updater, ok := result.(updateExpected); ok {
const stackIndex = 3 // Assert/Check, assert, RunComparison
err := updater.UpdatedExpected(stackIndex)
switch {
case err == nil:
return true
case errors.Is(err, source.ErrNotFound):
// do nothing, fallthrough to regular failure message
default:
t.Log("failed to update source", err)
return false
}
}
}
var message string
switch typed := result.(type) {
case resultWithComparisonArgs:
const stackIndex = 3 // Assert/Check, assert, runComparison
const stackIndex = 3 // Assert/Check, assert, RunComparison
args, err := source.CallExprArgs(stackIndex)
if err != nil {
t.Log(err.Error())
@ -50,6 +69,10 @@ type resultBasic interface {
FailureMessage() string
}
type updateExpected interface {
UpdatedExpected(stackIndex int) error
}
// filterPrintableExpr filters the ast.Expr slice to only include Expr that are
// easy to read when printed and contain relevant information to an assertion.
//
@ -88,15 +111,20 @@ func isShortPrintableExpr(expr ast.Expr) bool {
type argSelector func([]ast.Expr) []ast.Expr
func argsAfterT(args []ast.Expr) []ast.Expr {
// ArgsAfterT selects args starting at position 1. Used when the caller has a
// testing.T as the first argument, and the args to select should follow it.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
if len(args) < 1 {
return nil
}
return args[1:]
}
func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
if len(args) < 1 {
// ArgsFromComparisonCall selects args from the CallExpression at position 1.
// Used when the caller has a testing.T as the first argument, and the args to
// select are passed to the cmp.Comparison at position 1.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
if len(args) <= 1 {
return nil
}
if callExpr, ok := args[1].(*ast.CallExpr); ok {
@ -104,3 +132,15 @@ func argsFromComparisonCall(args []ast.Expr) []ast.Expr {
}
return nil
}
// ArgsAtZeroIndex selects args from the CallExpression at position 1.
// Used when the caller accepts a single cmp.Comparison argument.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
if len(args) == 0 {
return nil
}
if callExpr, ok := args[0].(*ast.CallExpr); ok {
return callExpr.Args
}
return nil
}

View File

@ -1,19 +1,20 @@
/*Package difflib is a partial port of Python difflib module.
/*
Package difflib is a partial port of Python difflib module.
Original source: https://github.com/pmezard/go-difflib
This file is trimmed to only the parts used by this repository.
*/
package difflib // import "gotest.tools/internal/difflib"
package difflib // import "gotest.tools/v3/internal/difflib"
func min(a, b int) int {
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func max(a, b int) int {
func maxInt(a, b int) int {
if a > b {
return a
}
@ -170,12 +171,15 @@ func (m *SequenceMatcher) isBJunk(s string) bool {
// If IsJunk is not defined:
//
// Return (i,j,k) such that a[i:i+k] is equal to b[j:j+k], where
// alo <= i <= i+k <= ahi
// blo <= j <= j+k <= bhi
//
// alo <= i <= i+k <= ahi
// blo <= j <= j+k <= bhi
//
// and for all (i',j',k') meeting those conditions,
// k >= k'
// i <= i'
// and if i == i', j <= j'
//
// k >= k'
// i <= i'
// and if i == i', j <= j'
//
// In other words, of all maximal matching blocks, return one that
// starts earliest in a, and of all those maximal matching blocks that
@ -393,12 +397,12 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
if codes[0].Tag == 'e' {
c := codes[0]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[0] = OpCode{c.Tag, max(i1, i2-n), i2, max(j1, j2-n), j2}
codes[0] = OpCode{c.Tag, maxInt(i1, i2-n), i2, maxInt(j1, j2-n), j2}
}
if codes[len(codes)-1].Tag == 'e' {
c := codes[len(codes)-1]
i1, i2, j1, j2 := c.I1, c.I2, c.J1, c.J2
codes[len(codes)-1] = OpCode{c.Tag, i1, min(i2, i1+n), j1, min(j2, j1+n)}
codes[len(codes)-1] = OpCode{c.Tag, i1, minInt(i2, i1+n), j1, minInt(j2, j1+n)}
}
nn := n + n
groups := [][]OpCode{}
@ -408,11 +412,11 @@ func (m *SequenceMatcher) GetGroupedOpCodes(n int) [][]OpCode {
// End the current group and start a new one whenever
// there is a large range with no changes.
if c.Tag == 'e' && i2-i1 > nn {
group = append(group, OpCode{c.Tag, i1, min(i2, i1+n),
j1, min(j2, j1+n)})
group = append(group, OpCode{c.Tag, i1, minInt(i2, i1+n),
j1, minInt(j2, j1+n)})
groups = append(groups, group)
group = []OpCode{}
i1, j1 = max(i1, i2-n), max(j1, j2-n)
i1, j1 = maxInt(i1, i2-n), maxInt(j1, j2-n)
}
group = append(group, OpCode{c.Tag, i1, i2, j1, j2})
}

View File

@ -1,3 +1,4 @@
// Package format provides utilities for formatting diffs and messages.
package format
import (
@ -6,7 +7,7 @@ import (
"strings"
"unicode"
"gotest.tools/internal/difflib"
"gotest.tools/v3/internal/difflib"
)
const (

View File

@ -1,4 +1,4 @@
package format // import "gotest.tools/internal/format"
package format // import "gotest.tools/v3/internal/format"
import "fmt"

View File

@ -0,0 +1,51 @@
package source
import (
"fmt"
"os"
"path/filepath"
)
// These Bazel env vars are documented here:
// https://bazel.build/reference/test-encyclopedia
// Signifies test executable is being driven by `bazel test`.
//
// Due to Bazel's compilation and sandboxing strategy,
// some care is required to handle resolving the original *.go source file.
var inBazelTest = os.Getenv("BAZEL_TEST") == "1"
// The name of the target being tested (ex: //some_package:some_package_test)
var bazelTestTarget = os.Getenv("TEST_TARGET")
// Absolute path to the base of the runfiles tree
var bazelTestSrcdir = os.Getenv("TEST_SRCDIR")
// The local repository's workspace name (ex: __main__)
var bazelTestWorkspace = os.Getenv("TEST_WORKSPACE")
func bazelSourcePath(filename string) (string, error) {
// Use the env vars to resolve the test source files,
// which must be provided as test data in the respective go_test target.
filename = filepath.Join(bazelTestSrcdir, bazelTestWorkspace, filename)
_, err := os.Stat(filename)
if os.IsNotExist(err) {
return "", fmt.Errorf(bazelMissingSourceMsg, filename, bazelTestTarget)
}
return filename, nil
}
var bazelMissingSourceMsg = `
the test source file does not exist: %s
It appears that you are running this test under Bazel (target: %s).
Check that your test source files are added as test data in your go_test targets.
Example:
go_test(
name = "your_package_test",
srcs = ["your_test.go"],
deps = ["@tools_gotest_v3//assert"],
data = glob(["*_test.go"])
)"
`

View File

@ -1,10 +1,9 @@
package source
import (
"fmt"
"go/ast"
"go/token"
"github.com/pkg/errors"
)
func scanToDeferLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
@ -29,11 +28,11 @@ func guessDefer(node ast.Node) (ast.Node, error) {
defers := collectDefers(node)
switch len(defers) {
case 0:
return nil, errors.New("failed to expression in defer")
return nil, fmt.Errorf("failed to find expression in defer")
case 1:
return defers[0].Call, nil
default:
return nil, errors.Errorf(
return nil, fmt.Errorf(
"ambiguous call expression: multiple (%d) defers in call block",
len(defers))
}

View File

@ -1,22 +1,19 @@
package source // import "gotest.tools/internal/source"
// Package source provides utilities for handling source-code.
package source // import "gotest.tools/v3/internal/source"
import (
"bytes"
"errors"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/pkg/errors"
)
const baseStackIndex = 1
// FormattedCallExprArg returns the argument from an ast.CallExpr at the
// index in the call stack. The argument is formatted using FormatNode.
func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
@ -33,28 +30,39 @@ func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
// the index in the call stack.
func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
_, filename, lineNum, ok := runtime.Caller(baseStackIndex + stackIndex)
_, filename, line, ok := runtime.Caller(stackIndex + 1)
if !ok {
return nil, errors.New("failed to get call stack")
}
debug("call stack position: %s:%d", filename, lineNum)
debug("call stack position: %s:%d", filename, line)
node, err := getNodeAtLine(filename, lineNum)
if err != nil {
return nil, err
// Normally, `go` will compile programs with absolute paths in
// the debug metadata. However, in the name of reproducibility,
// Bazel uses a compilation strategy that results in relative paths
// (otherwise, since Bazel uses a random tmp dir for compile and sandboxing,
// the resulting binaries would change across compiles/test runs).
if inBazelTest && !filepath.IsAbs(filename) {
var err error
filename, err = bazelSourcePath(filename)
if err != nil {
return nil, err
}
}
debug("found node: %s", debugFormatNode{node})
return getCallExprArgs(node)
}
func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
fileset := token.NewFileSet()
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
if err != nil {
return nil, errors.Wrapf(err, "failed to parse source file: %s", filename)
return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
}
expr, err := getCallExprArgs(fileset, astFile, line)
if err != nil {
return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
}
return expr, nil
}
func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
if node := scanToLine(fileset, astFile, lineNum); node != nil {
return node, nil
}
@ -64,8 +72,7 @@ func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
return node, err
}
}
return nil, errors.Errorf(
"failed to find an expression on line %d in %s", lineNum, filename)
return nil, errors.New("failed to find expression")
}
func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
@ -74,7 +81,7 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
switch {
case node == nil || matchedNode != nil:
return false
case nodePosition(fileset, node).Line == lineNum:
case fileset.Position(node.Pos()).Line == lineNum:
matchedNode = node
return false
}
@ -83,35 +90,18 @@ func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
return matchedNode
}
// In golang 1.9 the line number changed from being the line where the statement
// ended to the line where the statement began.
func nodePosition(fileset *token.FileSet, node ast.Node) token.Position {
if goVersionBefore19 {
return fileset.Position(node.End())
func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
node, err := getNodeAtLine(fileset, astFile, line)
if err != nil {
return nil, err
}
return fileset.Position(node.Pos())
}
var goVersionBefore19 = func() bool {
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
minor, err := strconv.ParseInt(parts[1], 10, 32)
return err == nil && parts[0] == "1" && minor < 9
}()
debug("found node: %s", debugFormatNode{node})
func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
visitor := &callExprVisitor{}
ast.Walk(visitor, node)
if visitor.expr == nil {
return nil, errors.New("failed to find call expression")
return nil, errors.New("failed to find an expression")
}
debug("callExpr: %s", debugFormatNode{visitor.expr})
return visitor.expr.Args, nil
@ -158,6 +148,9 @@ type debugFormatNode struct {
}
func (n debugFormatNode) String() string {
if n.Node == nil {
return "none"
}
out, err := FormatNode(n.Node)
if err != nil {
return fmt.Sprintf("failed to format %s: %s", n.Node, err)

View File

@ -0,0 +1,171 @@
package source
import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"runtime"
"strings"
)
// IsUpdate is returns true if the -update flag is set. It indicates the user
// running the tests would like to update any golden values.
func IsUpdate() bool {
if Update {
return true
}
return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
}
// Update is a shim for testing, and for compatibility with the old -update-golden
// flag.
var Update bool
func init() {
if f := flag.Lookup("update"); f != nil {
getter, ok := f.Value.(flag.Getter)
msg := "some other package defined an incompatible -update flag, expected a flag.Bool"
if !ok {
panic(msg)
}
if _, ok := getter.Get().(bool); !ok {
panic(msg)
}
return
}
flag.Bool("update", false, "update golden values")
}
// ErrNotFound indicates that UpdateExpectedValue failed to find the
// variable to update, likely because it is not a package level variable.
var ErrNotFound = fmt.Errorf("failed to find variable for update of golden value")
// UpdateExpectedValue looks for a package-level variable with a name that
// starts with expected in the arguments to the caller. If the variable is
// found, the value of the variable will be updated to value of the other
// argument to the caller.
func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
_, filename, line, ok := runtime.Caller(stackIndex + 1)
if !ok {
return errors.New("failed to get call stack")
}
debug("call stack position: %s:%d", filename, line)
fileset := token.NewFileSet()
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
if err != nil {
return fmt.Errorf("failed to parse source file %s: %w", filename, err)
}
expr, err := getCallExprArgs(fileset, astFile, line)
if err != nil {
return fmt.Errorf("call from %s:%d: %w", filename, line, err)
}
if len(expr) < 3 {
debug("not enough arguments %d: %v",
len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
}
argIndex, ident := getIdentForExpectedValueArg(expr)
if argIndex < 0 || ident == nil {
debug("no arguments started with the word 'expected': %v",
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
return ErrNotFound
}
value := x
if argIndex == 1 {
value = y
}
strValue, ok := value.(string)
if !ok {
debug("value must be type string, got %T", value)
return ErrNotFound
}
return UpdateVariable(filename, fileset, astFile, ident, strValue)
}
// UpdateVariable writes to filename the contents of astFile with the value of
// the variable updated to value.
func UpdateVariable(
filename string,
fileset *token.FileSet,
astFile *ast.File,
ident *ast.Ident,
value string,
) error {
obj := ident.Obj
if obj == nil {
return ErrNotFound
}
if obj.Kind != ast.Con && obj.Kind != ast.Var {
debug("can only update var and const, found %v", obj.Kind)
return ErrNotFound
}
switch decl := obj.Decl.(type) {
case *ast.ValueSpec:
if len(decl.Names) != 1 {
debug("more than one name in ast.ValueSpec")
return ErrNotFound
}
decl.Values[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}
case *ast.AssignStmt:
if len(decl.Lhs) != 1 {
debug("more than one name in ast.AssignStmt")
return ErrNotFound
}
decl.Rhs[0] = &ast.BasicLit{
Kind: token.STRING,
Value: "`" + value + "`",
}
default:
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
return ErrNotFound
}
var buf bytes.Buffer
if err := format.Node(&buf, fileset, astFile); err != nil {
return fmt.Errorf("failed to format file after update: %w", err)
}
fh, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to open file %v: %w", filename, err)
}
if _, err = fh.Write(buf.Bytes()); err != nil {
return fmt.Errorf("failed to write file %v: %w", filename, err)
}
if err := fh.Sync(); err != nil {
return fmt.Errorf("failed to sync file %v: %w", filename, err)
}
return nil
}
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
for i := 1; i < 3; i++ {
switch e := expr[i].(type) {
case *ast.Ident:
if strings.HasPrefix(strings.ToLower(e.Name), "expected") {
return i, e
}
}
}
return -1, nil
}

View File

@ -0,0 +1,35 @@
package source
import (
"runtime"
"strconv"
"strings"
)
// GoVersionLessThan returns true if runtime.Version() is semantically less than
// version major.minor. Returns false if a release version can not be parsed from
// runtime.Version().
func GoVersionLessThan(major, minor int64) bool {
version := runtime.Version()
// not a release version
if !strings.HasPrefix(version, "go") {
return false
}
version = strings.TrimPrefix(version, "go")
parts := strings.Split(version, ".")
if len(parts) < 2 {
return false
}
rMajor, err := strconv.ParseInt(parts[0], 10, 32)
if err != nil {
return false
}
if rMajor != major {
return rMajor < major
}
rMinor, err := strconv.ParseInt(parts[1], 10, 32)
if err != nil {
return false
}
return rMinor < minor
}

15
vendor/modules.txt vendored
View File

@ -224,10 +224,11 @@ google.golang.org/protobuf/types/known/anypb
# gopkg.in/yaml.v3 v3.0.1
## explicit
gopkg.in/yaml.v3
# gotest.tools v2.2.0+incompatible
## explicit
gotest.tools/assert
gotest.tools/assert/cmp
gotest.tools/internal/difflib
gotest.tools/internal/format
gotest.tools/internal/source
# gotest.tools/v3 v3.5.2
## explicit; go 1.17
gotest.tools/v3/assert
gotest.tools/v3/assert/cmp
gotest.tools/v3/internal/assert
gotest.tools/v3/internal/difflib
gotest.tools/v3/internal/format
gotest.tools/v3/internal/source